chenzihong-gavin committed on
Commit acd7cf4 · 1 Parent(s): 6505eee
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .env.example +6 -0
  2. .gitignore +179 -0
  3. LICENSE +201 -0
  4. graphgen/__init__.py +0 -0
  5. graphgen/configs/config.yaml.example +16 -0
  6. graphgen/configs/graphgen_config.yaml +16 -0
  7. graphgen/evaluate.py +142 -0
  8. graphgen/generate.py +101 -0
  9. graphgen/graphgen.py +260 -0
  10. graphgen/judge.py +60 -0
  11. graphgen/models/__init__.py +41 -0
  12. graphgen/models/embed/__init__.py +0 -0
  13. graphgen/models/embed/embedding.py +29 -0
  14. graphgen/models/evaluate/__init__.py +0 -0
  15. graphgen/models/evaluate/base_evaluator.py +51 -0
  16. graphgen/models/evaluate/length_evaluator.py +22 -0
  17. graphgen/models/evaluate/mtld_evaluator.py +76 -0
  18. graphgen/models/evaluate/reward_evaluator.py +101 -0
  19. graphgen/models/evaluate/uni_evaluator.py +159 -0
  20. graphgen/models/llm/__init__.py +0 -0
  21. graphgen/models/llm/limitter.py +88 -0
  22. graphgen/models/llm/openai_model.py +130 -0
  23. graphgen/models/llm/tokenizer.py +73 -0
  24. graphgen/models/llm/topk_token_model.py +48 -0
  25. graphgen/models/search/__init__.py +0 -0
  26. graphgen/models/search/wiki_search.py +36 -0
  27. graphgen/models/storage/__init__.py +0 -0
  28. graphgen/models/storage/base_storage.py +94 -0
  29. graphgen/models/storage/json_storage.py +51 -0
  30. graphgen/models/storage/networkx_storage.py +159 -0
  31. graphgen/models/strategy/__init__.py +0 -0
  32. graphgen/models/strategy/base_strategy.py +5 -0
  33. graphgen/models/strategy/travserse_strategy.py +30 -0
  34. graphgen/models/text/__init__.py +0 -0
  35. graphgen/models/text/chunk.py +7 -0
  36. graphgen/models/text/text_pair.py +9 -0
  37. graphgen/operators/__init__.py +16 -0
  38. graphgen/operators/extract_kg.py +132 -0
  39. graphgen/operators/judge.py +188 -0
  40. graphgen/operators/merge_kg.py +215 -0
  41. graphgen/operators/quiz.py +109 -0
  42. graphgen/operators/resolute_coreference.py +33 -0
  43. graphgen/operators/search_wikipedia.py +71 -0
  44. graphgen/operators/split_graph.py +333 -0
  45. graphgen/operators/traverse_graph.py +485 -0
  46. graphgen/templates/__init__.py +9 -0
  47. graphgen/templates/answer_rephrasing.py +219 -0
  48. graphgen/templates/coreference_resolution.py +39 -0
  49. graphgen/templates/description_rephrasing.py +121 -0
  50. graphgen/templates/kg_extraction.py +210 -0
.env.example ADDED
@@ -0,0 +1,6 @@
+ SYNTHESIZER_MODEL=
+ SYNTHESIZER_BASE_URL=
+ SYNTHESIZER_API_KEY=
+ TRAINEE_MODEL=
+ TRAINEE_BASE_URL=
+ TRAINEE_API_KEY=
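
Note: these variables are read at runtime with python-dotenv; graphgen/generate.py below builds the two OpenAI-compatible clients from them. A minimal sketch of that lookup (placeholder values, no real credentials):

    from dotenv import load_dotenv
    import os

    load_dotenv()  # reads .env from the current working directory
    model = os.getenv("SYNTHESIZER_MODEL")  # an OpenAI-compatible model name
    base_url = os.getenv("SYNTHESIZER_BASE_URL")
    api_key = os.getenv("SYNTHESIZER_API_KEY")
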
.gitignore ADDED
@@ -0,0 +1,179 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ cache
+ *.pyc
+ *.html
+ .gradio
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
graphgen/__init__.py ADDED
File without changes
graphgen/configs/config.yaml.example ADDED
@@ -0,0 +1,16 @@
+ data_type: raw
+ input_file: resources/examples/raw_demo.jsonl
+ tokenizer: cl100k_base
+ quiz_samples: 2
+ traverse_strategy:
+   qa_form: atomic
+   bidirectional: true
+   edge_sampling: max_loss
+   expand_method: max_tokens
+   isolated_node_strategy: add
+   max_depth: 2
+   max_extra_edges: 5
+   max_tokens: 256
+   loss_strategy: only_edge
+ web_search: false
+ re_judge: false
graphgen/configs/graphgen_config.yaml ADDED
@@ -0,0 +1,16 @@
+ data_type: raw
+ input_file: resources/examples/raw_demo.jsonl
+ tokenizer: cl100k_base
+ quiz_samples: 2
+ traverse_strategy:
+   qa_form: aggregated
+   bidirectional: true
+   edge_sampling: max_loss
+   expand_method: max_width
+   isolated_node_strategy: ignore
+   max_depth: 1
+   max_extra_edges: 2
+   max_tokens: 256
+   loss_strategy: only_edge
+ web_search: false
+ re_judge: false
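
Note: generate.py below loads this file with yaml.load and splats the traverse_strategy mapping into the TraverseStrategy dataclass, so every key under traverse_strategy must match a field name. A minimal sketch of that flow (paths assume the repo layout above):

    import yaml
    from graphgen.models import TraverseStrategy

    with open("graphgen/configs/graphgen_config.yaml", "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    # each key under traverse_strategy becomes a constructor argument
    strategy = TraverseStrategy(**config["traverse_strategy"])
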
graphgen/evaluate.py ADDED
@@ -0,0 +1,142 @@
+ """Evaluate the quality of the generated text using various metrics"""
+
+ import os
+ import json
+ import argparse
+ import pandas as pd
+ from dotenv import load_dotenv
+ from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, TextPair, UniEvaluator
+ from .utils import logger, set_logger
+
+ sys_path = os.path.abspath(os.path.dirname(__file__))
+ set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))
+
+ load_dotenv()
+
+ def evaluate_length(corpus, tokenizer_name):
+     length_evaluator = LengthEvaluator(
+         tokenizer_name=tokenizer_name
+     )
+     logger.info("Length evaluator loaded")
+     scores = length_evaluator.get_average_score(corpus)
+     logger.info("Length scores: %s", scores)
+     return scores
+
+ def evaluate_mtld(corpus):
+     mtld_evaluator = MTLDEvaluator()
+     logger.info("MTLD evaluator loaded")
+     scores = mtld_evaluator.get_average_score(corpus)
+     logger.info("MTLD scores: %s", scores)
+     min_max_scores = mtld_evaluator.get_min_max_score(corpus)
+     logger.info("MTLD min max scores: %s", min_max_scores)
+     return scores, min_max_scores
+
+ def evaluate_reward(corpus, reward_model_names):
+     scores = []
+     for reward_name in reward_model_names:
+         reward_evaluator = RewardEvaluator(
+             reward_name=reward_name
+         )
+         logger.info("Loaded reward model: %s", reward_name)
+         average_score = reward_evaluator.get_average_score(corpus)
+         logger.info("%s scores: %s", reward_name, average_score)
+         min_max_scores = reward_evaluator.get_min_max_score(corpus)
+         logger.info("%s min max scores: %s", reward_name, min_max_scores)
+         scores.append({
+             'reward_name': reward_name.split('/')[-1],
+             'score': average_score,
+             'min_max_scores': min_max_scores
+         })
+         del reward_evaluator
+         clean_gpu_cache()
+     return scores
+
+ def evaluate_uni(corpus, uni_model_name):
+     uni_evaluator = UniEvaluator(
+         model_name=uni_model_name
+     )
+     logger.info("Uni evaluator loaded with model %s", uni_model_name)
+     uni_scores = uni_evaluator.get_average_score(corpus)
+     for key, value in uni_scores.items():
+         logger.info("Uni %s scores: %s", key, value)
+     min_max_scores = uni_evaluator.get_min_max_score(corpus)
+     for key, value in min_max_scores.items():
+         logger.info("Uni %s min max scores: %s", key, value)
+     del uni_evaluator
+     clean_gpu_cache()
+     return (uni_scores['naturalness'], uni_scores['coherence'], uni_scores['understandability'],
+             min_max_scores['naturalness'], min_max_scores['coherence'], min_max_scores['understandability'])
+
+
+ def clean_gpu_cache():
+     import torch
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+
+ if __name__ == '__main__':
+     import torch.multiprocessing as mp
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument('--folder', type=str, default='cache/data', help='folder to load data')
+     parser.add_argument('--output', type=str, default='cache/output', help='path to save output')
+
+     parser.add_argument('--tokenizer', type=str, default='cl100k_base', help='tokenizer name')
+     parser.add_argument('--reward', type=str, default='OpenAssistant/reward-model-deberta-v3-large-v2',
+                         help='Comma-separated list of reward models')
+     parser.add_argument('--uni', type=str, default='MingZhong/unieval-sum', help='uni model name')
+
+     args = parser.parse_args()
+
+     if not os.path.exists(args.folder):
+         raise ValueError(f"Folder {args.folder} does not exist")
+
+     if not os.path.exists(args.output):
+         os.makedirs(args.output)
+
+     reward_models = args.reward.split(',')
+
+     results = []
+
+     logger.info("Data loaded from %s", args.folder)
+     mp.set_start_method('spawn')
+
+     for file in os.listdir(args.folder):
+         if file.endswith('.json'):
+             logger.info("Processing %s", file)
+             with open(os.path.join(args.folder, file), 'r', encoding='utf-8') as f:
+                 data = json.load(f)
+             data = [TextPair(
+                 question=data[key]['question'],
+                 answer=data[key]['answer']
+             ) for key in data]
+
+             length_scores = evaluate_length(data, args.tokenizer)
+             mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
+             reward_scores = evaluate_reward(data, reward_models)
+             uni_naturalness_scores, uni_coherence_scores, uni_understandability_scores, \
+                 min_max_uni_naturalness_scores, min_max_uni_coherence_scores, min_max_uni_understandability_scores \
+                 = evaluate_uni(data, args.uni)
+
+             result = {
+                 'file': file,
+                 'number': len(data),
+                 'length': length_scores,
+                 'mtld': mtld_scores,
+                 'mtld_min_max': min_max_mtld_scores,
+                 'uni_naturalness': uni_naturalness_scores,
+                 'uni_coherence': uni_coherence_scores,
+                 'uni_understandability': uni_understandability_scores,
+                 'uni_naturalness_min_max': min_max_uni_naturalness_scores,
+                 'uni_coherence_min_max': min_max_uni_coherence_scores,
+                 'uni_understandability_min_max': min_max_uni_understandability_scores
+             }
+             for reward_score in reward_scores:
+                 result[reward_score['reward_name']] = reward_score['score']
+                 result[f"{reward_score['reward_name']}_min_max"] = reward_score['min_max_scores']
+
+             results.append(result)
+
+     results = pd.DataFrame(results)
+     results.to_csv(os.path.join(args.output, 'evaluation.csv'), index=False)
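
Note: the main loop above expects each *.json file under --folder to map arbitrary keys to objects with question and answer fields. A sketch of writing a compatible input file (file name and QA text are illustrative):

    import json

    qa = {
        "qa-0": {"question": "What does GraphGen generate?", "answer": "QA pairs from a knowledge graph."},
    }
    with open("cache/data/demo.json", "w", encoding="utf-8") as f:
        json.dump(qa, f, ensure_ascii=False)
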
graphgen/generate.py ADDED
@@ -0,0 +1,101 @@
+ import os
+ import json
+ import time
+ import argparse
+ from importlib.resources import files
+ import yaml
+ from dotenv import load_dotenv
+
+ from .graphgen import GraphGen
+ from .models import OpenAIModel, Tokenizer, TraverseStrategy
+ from .utils import set_logger
+
+ sys_path = os.path.abspath(os.path.dirname(__file__))
+
+ load_dotenv()
+
+ def set_working_dir(folder):
+     os.makedirs(folder, exist_ok=True)
+     os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True)
+     os.makedirs(os.path.join(folder, "logs"), exist_ok=True)
+
+ def save_config(config_path, global_config):
+     if not os.path.exists(os.path.dirname(config_path)):
+         os.makedirs(os.path.dirname(config_path))
+     with open(config_path, "w", encoding='utf-8') as config_file:
+         yaml.dump(global_config, config_file, default_flow_style=False, allow_unicode=True)
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config_file',
+                         help='Config parameters for GraphGen.',
+                         # default=os.path.join(sys_path, "configs", "graphgen_config.yaml"),
+                         default=files('graphgen').joinpath("configs", "graphgen_config.yaml"),
+                         type=str)
+     parser.add_argument('--output_dir',
+                         help='Output directory for GraphGen.',
+                         default=sys_path,
+                         required=True,
+                         type=str)
+
+     args = parser.parse_args()
+
+     working_dir = args.output_dir
+     set_working_dir(working_dir)
+     unique_id = int(time.time())
+     set_logger(os.path.join(working_dir, "logs", f"graphgen_{unique_id}.log"), if_stream=False)
+
+     with open(args.config_file, "r", encoding='utf-8') as f:
+         config = yaml.load(f, Loader=yaml.FullLoader)
+
+     input_file = config['input_file']
+
+     if config['data_type'] == 'raw':
+         with open(input_file, "r", encoding='utf-8') as f:
+             data = [json.loads(line) for line in f]
+     elif config['data_type'] == 'chunked':
+         with open(input_file, "r", encoding='utf-8') as f:
+             data = json.load(f)
+     else:
+         raise ValueError(f"Invalid data type: {config['data_type']}")
+
+     synthesizer_llm_client = OpenAIModel(
+         model_name=os.getenv("SYNTHESIZER_MODEL"),
+         api_key=os.getenv("SYNTHESIZER_API_KEY"),
+         base_url=os.getenv("SYNTHESIZER_BASE_URL")
+     )
+     trainee_llm_client = OpenAIModel(
+         model_name=os.getenv("TRAINEE_MODEL"),
+         api_key=os.getenv("TRAINEE_API_KEY"),
+         base_url=os.getenv("TRAINEE_BASE_URL")
+     )
+
+     traverse_strategy = TraverseStrategy(
+         **config['traverse_strategy']
+     )
+
+     graph_gen = GraphGen(
+         working_dir=working_dir,
+         unique_id=unique_id,
+         synthesizer_llm_client=synthesizer_llm_client,
+         trainee_llm_client=trainee_llm_client,
+         if_web_search=config['web_search'],
+         tokenizer_instance=Tokenizer(
+             model_name=config['tokenizer']
+         ),
+         traverse_strategy=traverse_strategy
+     )
+
+     graph_gen.insert(data, config['data_type'])
+
+     graph_gen.quiz(max_samples=config['quiz_samples'])
+
+     graph_gen.judge(re_judge=config["re_judge"])
+
+     graph_gen.traverse()
+
+     path = os.path.join(working_dir, "data", "graphgen", str(unique_id), f"config-{unique_id}.yaml")
+     save_config(path, config)
+
+ if __name__ == '__main__':
+     main()
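
Note: with data_type: raw, every line of input_file is parsed as an independent JSON object, and async_split_chunks in graphgen.py below reads its content field. An illustrative writer for such a JSONL file (document text invented):

    import json

    docs = [{"content": "First source document..."},
            {"content": "Second source document..."}]
    with open("resources/examples/raw_demo.jsonl", "w", encoding="utf-8") as f:
        for doc in docs:
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")
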
graphgen/graphgen.py ADDED
@@ -0,0 +1,260 @@
+ # Adapted from https://github.com/HKUDS/LightRAG
+
+ import asyncio
+ import os
+ import time
+ from dataclasses import dataclass, field
+ from typing import List, Union, cast
+
+ import gradio as gr
+ from tqdm.asyncio import tqdm as tqdm_async
+
+ from .models import (
+     Chunk,
+     JsonKVStorage,
+     NetworkXStorage,
+     OpenAIModel,
+     Tokenizer,
+     TraverseStrategy,
+     WikiSearch,
+ )
+ from .models.storage.base_storage import StorageNameSpace
+ from .operators import (
+     extract_kg,
+     judge_statement,
+     quiz,
+     search_wikipedia,
+     skip_judge_statement,
+     traverse_graph_atomically,
+     traverse_graph_by_edge,
+     traverse_graph_for_multi_hop,
+ )
+ from .utils import compute_content_hash, create_event_loop, logger
+
+ sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+
+ @dataclass
+ class GraphGen:
+     unique_id: int = int(time.time())
+     working_dir: str = os.path.join(sys_path, "cache")
+
+     # text chunking
+     chunk_size: int = 1024
+     chunk_overlap_size: int = 100
+
+     # llm
+     synthesizer_llm_client: OpenAIModel = None
+     trainee_llm_client: OpenAIModel = None
+     tokenizer_instance: Tokenizer = None
+
+     # web search
+     if_web_search: bool = False
+     wiki_client: WikiSearch = field(default_factory=WikiSearch)
+
+     # traverse strategy
+     traverse_strategy: TraverseStrategy = field(default_factory=TraverseStrategy)
+
+     # webui
+     progress_bar: gr.Progress = None
+
+     def __post_init__(self):
+         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
+             self.working_dir, namespace="full_docs"
+         )
+         self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
+             self.working_dir, namespace="text_chunks"
+         )
+         self.wiki_storage: JsonKVStorage = JsonKVStorage(
+             self.working_dir, namespace="wiki"
+         )
+         self.graph_storage: NetworkXStorage = NetworkXStorage(
+             self.working_dir, namespace="graph"
+         )
+         self.rephrase_storage: JsonKVStorage = JsonKVStorage(
+             self.working_dir, namespace="rephrase"
+         )
+         self.qa_storage: JsonKVStorage = JsonKVStorage(
+             os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)), namespace=f"qa-{self.unique_id}"
+         )
+
+     async def async_split_chunks(self, data: Union[List[list], List[dict]], data_type: str) -> dict:
+         # TODO: decide whether to apply coreference resolution here
+         if len(data) == 0:
+             return {}
+
+         new_docs = {}
+         inserting_chunks = {}
+         if data_type == "raw":
+             assert isinstance(data, list) and isinstance(data[0], dict)
+             # compute hash for each document
+             new_docs = {
+                 compute_content_hash(doc['content'], prefix="doc-"): {'content': doc['content']} for doc in data
+             }
+             _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
+             new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+             if len(new_docs) == 0:
+                 logger.warning("All docs are already in the storage")
+                 return {}
+             logger.info("[New Docs] inserting %d docs", len(new_docs))
+
+             cur_index = 1
+             doc_number = len(new_docs)
+             async for doc_key, doc in tqdm_async(
+                 new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
+             ):
+                 chunks = {
+                     compute_content_hash(dp["content"], prefix="chunk-"): {
+                         **dp,
+                         'full_doc_id': doc_key
+                     } for dp in self.tokenizer_instance.chunk_by_token_size(doc["content"],
+                                                                             self.chunk_overlap_size, self.chunk_size)
+                 }
+                 inserting_chunks.update(chunks)
+
+                 if self.progress_bar is not None:
+                     self.progress_bar(
+                         cur_index / doc_number, f"Chunking {doc_key}"
+                     )
+                 cur_index += 1
+
+             _add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys()))
+             inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys}
+         elif data_type == "chunked":
+             assert isinstance(data, list) and isinstance(data[0], list)
+             new_docs = {
+                 compute_content_hash("".join(chunk['content']), prefix="doc-"): {'content': "".join(chunk['content'])}
+                 for doc in data for chunk in doc
+             }
+             _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
+             new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+             if len(new_docs) == 0:
+                 logger.warning("All docs are already in the storage")
+                 return {}
+             logger.info("[New Docs] inserting %d docs", len(new_docs))
+             async for doc in tqdm_async(data, desc="[1/4]Chunking documents", unit="doc"):
+                 doc_str = "".join([chunk['content'] for chunk in doc])
+                 for chunk in doc:
+                     chunk_key = compute_content_hash(chunk['content'], prefix="chunk-")
+                     inserting_chunks[chunk_key] = {
+                         **chunk,
+                         'full_doc_id': compute_content_hash(doc_str, prefix="doc-")
+                     }
+             _add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys()))
+             inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys}
+
+         await self.full_docs_storage.upsert(new_docs)
+         await self.text_chunks_storage.upsert(inserting_chunks)
+
+         return inserting_chunks
+
+     def insert(self, data: Union[List[list], List[dict]], data_type: str):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_insert(data, data_type))
+
+     async def async_insert(self, data: Union[List[list], List[dict]], data_type: str):
+         """
+         Insert chunks into the graph.
+         """
+         inserting_chunks = await self.async_split_chunks(data, data_type)
+
+         if len(inserting_chunks) == 0:
+             logger.warning("All chunks are already in the storage")
+             return
+         logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
+
+         logger.info("[Entity and Relation Extraction]...")
+         _add_entities_and_relations = await extract_kg(
+             llm_client=self.synthesizer_llm_client,
+             kg_instance=self.graph_storage,
+             tokenizer_instance=self.tokenizer_instance,
+             chunks=[Chunk(id=k, content=v['content']) for k, v in inserting_chunks.items()],
+             progress_bar=self.progress_bar,
+         )
+         if not _add_entities_and_relations:
+             logger.warning("No entities or relations extracted")
+             return
+
+         logger.info("[Wiki Search] is %s", 'enabled' if self.if_web_search else 'disabled')
+         if self.if_web_search:
+             logger.info("[Wiki Search]...")
+             _add_wiki_data = await search_wikipedia(
+                 llm_client=self.synthesizer_llm_client,
+                 wiki_search_client=self.wiki_client,
+                 knowledge_graph_instance=_add_entities_and_relations
+             )
+             await self.wiki_storage.upsert(_add_wiki_data)
+
+         await self._insert_done()
+
+     async def _insert_done(self):
+         tasks = []
+         for storage_instance in [self.full_docs_storage, self.text_chunks_storage,
+                                  self.graph_storage, self.wiki_storage]:
+             if storage_instance is None:
+                 continue
+             tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
+         await asyncio.gather(*tasks)
+
+     def quiz(self, max_samples=1):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_quiz(max_samples))
+
+     async def async_quiz(self, max_samples=1):
+         await quiz(self.synthesizer_llm_client, self.graph_storage, self.rephrase_storage, max_samples)
+         await self.rephrase_storage.index_done_callback()
+
+     def judge(self, re_judge=False, skip=False):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_judge(re_judge, skip))
+
+     async def async_judge(self, re_judge=False, skip=False):
+         if skip:
+             _update_relations = await skip_judge_statement(self.graph_storage)
+         else:
+             _update_relations = await judge_statement(self.trainee_llm_client, self.graph_storage,
+                                                       self.rephrase_storage, re_judge)
+         await _update_relations.index_done_callback()
+
+     def traverse(self):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_traverse())
+
+     async def async_traverse(self):
+         if self.traverse_strategy.qa_form == "atomic":
+             results = await traverse_graph_atomically(self.synthesizer_llm_client,
+                                                       self.tokenizer_instance,
+                                                       self.graph_storage,
+                                                       self.traverse_strategy,
+                                                       self.text_chunks_storage,
+                                                       self.progress_bar)
+         elif self.traverse_strategy.qa_form == "multi_hop":
+             results = await traverse_graph_for_multi_hop(self.synthesizer_llm_client,
+                                                          self.tokenizer_instance,
+                                                          self.graph_storage,
+                                                          self.traverse_strategy,
+                                                          self.text_chunks_storage,
+                                                          self.progress_bar)
+         elif self.traverse_strategy.qa_form == "aggregated":
+             results = await traverse_graph_by_edge(self.synthesizer_llm_client, self.tokenizer_instance,
+                                                    self.graph_storage, self.traverse_strategy,
+                                                    self.text_chunks_storage, self.progress_bar)
+         else:
+             raise ValueError(f"Unknown qa_form: {self.traverse_strategy.qa_form}")
+         await self.qa_storage.upsert(results)
+         await self.qa_storage.index_done_callback()
+
+     def clear(self):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_clear())
+
+     async def async_clear(self):
+         await self.full_docs_storage.drop()
+         await self.text_chunks_storage.drop()
+         await self.wiki_storage.drop()
+         await self.graph_storage.clear()
+         await self.rephrase_storage.drop()
+         await self.qa_storage.drop()
+
+         logger.info("All caches are cleared")
graphgen/judge.py ADDED
@@ -0,0 +1,60 @@
+ import os
+ import argparse
+ import asyncio
+ from dotenv import load_dotenv
+
+ from .models import NetworkXStorage, JsonKVStorage, OpenAIModel
+ from .operators import judge_statement
+
+ sys_path = os.path.abspath(os.path.dirname(__file__))
+
+ load_dotenv()
+
+ def calculate_average_loss(graph: NetworkXStorage):
+     """
+     Calculate the average loss of the graph.
+
+     :param graph: NetworkXStorage
+     :return: float
+     """
+     edges = asyncio.run(graph.get_all_edges())
+     total_loss = 0
+     for edge in edges:
+         total_loss += edge[2]['loss']
+     return total_loss / len(edges)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--input', type=str, default=os.path.join(sys_path, "cache"), help='path to load input graph')
+     parser.add_argument('--output', type=str, default='cache/output/new_graph.graphml', help='path to save output')
+
+     args = parser.parse_args()
+
+     llm_client = OpenAIModel(
+         model_name=os.getenv("TRAINEE_MODEL"),
+         api_key=os.getenv("TRAINEE_API_KEY"),
+         base_url=os.getenv("TRAINEE_BASE_URL")
+     )
+
+     graph_storage = NetworkXStorage(
+         args.input,
+         namespace="graph"
+     )
+     average_loss = calculate_average_loss(graph_storage)
+     print(f"Average loss of the graph: {average_loss}")
+
+     rephrase_storage = JsonKVStorage(
+         os.path.join(sys_path, "cache"),
+         namespace="rephrase"
+     )
+
+     new_graph = asyncio.run(judge_statement(llm_client, graph_storage, rephrase_storage, re_judge=True))
+
+     graph_file = asyncio.run(graph_storage.get_graph())
+
+     new_graph.write_nx_graph(graph_file, args.output)
+
+     average_loss = calculate_average_loss(new_graph)
+     print(f"Average loss of the graph: {average_loss}")
graphgen/models/__init__.py ADDED
@@ -0,0 +1,41 @@
+ from .text.chunk import Chunk
+ from .text.text_pair import TextPair
+
+ from .llm.topk_token_model import Token, TopkTokenModel
+ from .llm.openai_model import OpenAIModel
+ from .llm.tokenizer import Tokenizer
+
+ from .storage.networkx_storage import NetworkXStorage
+ from .storage.json_storage import JsonKVStorage
+
+ from .search.wiki_search import WikiSearch
+
+ from .evaluate.length_evaluator import LengthEvaluator
+ from .evaluate.mtld_evaluator import MTLDEvaluator
+ from .evaluate.reward_evaluator import RewardEvaluator
+ from .evaluate.uni_evaluator import UniEvaluator
+
+ from .strategy.travserse_strategy import TraverseStrategy
+
+
+ __all__ = [
+     # llm models
+     "OpenAIModel",
+     "TopkTokenModel",
+     "Token",
+     "Tokenizer",
+     # storage models
+     "Chunk",
+     "NetworkXStorage",
+     "JsonKVStorage",
+     # search models
+     "WikiSearch",
+     # evaluate models
+     "TextPair",
+     "LengthEvaluator",
+     "MTLDEvaluator",
+     "RewardEvaluator",
+     "UniEvaluator",
+     # strategy models
+     "TraverseStrategy",
+ ]
graphgen/models/embed/__init__.py ADDED
File without changes
graphgen/models/embed/embedding.py ADDED
@@ -0,0 +1,29 @@
+ from dataclasses import dataclass
+ import asyncio
+ import numpy as np
+
+ class UnlimitedSemaphore:
+     """A context manager that allows unlimited access."""
+
+     async def __aenter__(self):
+         pass
+
+     async def __aexit__(self, exc_type, exc, tb):
+         pass
+
+ @dataclass
+ class EmbeddingFunc:
+     embedding_dim: int
+     max_token_size: int
+     func: callable
+     concurrent_limit: int = 16
+
+     def __post_init__(self):
+         if self.concurrent_limit != 0:
+             self._semaphore = asyncio.Semaphore(self.concurrent_limit)
+         else:
+             self._semaphore = UnlimitedSemaphore()
+
+     async def __call__(self, *args, **kwargs) -> np.ndarray:
+         async with self._semaphore:
+             return await self.func(*args, **kwargs)
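
Note: EmbeddingFunc gates concurrent calls with an asyncio.Semaphore unless concurrent_limit is 0. A usage sketch with a dummy embedding coroutine (the 768-dim zero vectors are a stand-in for a real model call):

    import asyncio
    import numpy as np

    async def fake_embed(texts):
        return np.zeros((len(texts), 768))  # stand-in embedding

    embed = EmbeddingFunc(embedding_dim=768, max_token_size=8192, func=fake_embed)
    vectors = asyncio.run(embed(["hello", "world"]))  # at most 16 calls run concurrently
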
graphgen/models/evaluate/__init__.py ADDED
File without changes
graphgen/models/evaluate/base_evaluator.py ADDED
@@ -0,0 +1,51 @@
+ import asyncio
+
+ from dataclasses import dataclass
+ from tqdm.asyncio import tqdm as tqdm_async
+ from graphgen.utils import create_event_loop
+ from graphgen.models.text.text_pair import TextPair
+
+ @dataclass
+ class BaseEvaluator:
+     max_concurrent: int = 100
+     results: list[float] = None
+
+     def evaluate(self, pairs: list[TextPair]) -> list[float]:
+         """
+         Evaluate the text and return a score.
+         """
+         return create_event_loop().run_until_complete(self.async_evaluate(pairs))
+
+     async def async_evaluate(self, pairs: list[TextPair]) -> list[float]:
+         semaphore = asyncio.Semaphore(self.max_concurrent)
+
+         async def evaluate_with_semaphore(pair):
+             async with semaphore:  # acquire the semaphore
+                 return await self.evaluate_single(pair)
+
+         results = []
+         for result in tqdm_async(
+             asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]),
+             total=len(pairs),
+         ):
+             results.append(await result)
+         return results
+
+     async def evaluate_single(self, pair: TextPair) -> float:
+         raise NotImplementedError()
+
+     def get_average_score(self, pairs: list[TextPair]) -> float:
+         """
+         Get the average score of a batch of texts.
+         """
+         results = self.evaluate(pairs)
+         self.results = results
+         return sum(self.results) / len(pairs)
+
+     def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
+         """
+         Get the min and max score of a batch of texts.
+         """
+         if self.results is None:
+             self.get_average_score(pairs)
+         return min(self.results), max(self.results)
graphgen/models/evaluate/length_evaluator.py ADDED
@@ -0,0 +1,22 @@
+ from dataclasses import dataclass
+ from graphgen.models.evaluate.base_evaluator import BaseEvaluator
+ from graphgen.models.llm.tokenizer import Tokenizer
+ from graphgen.models.text.text_pair import TextPair
+ from graphgen.utils import create_event_loop
+
+
+ @dataclass
+ class LengthEvaluator(BaseEvaluator):
+     tokenizer_name: str = "cl100k_base"
+     def __post_init__(self):
+         self.tokenizer = Tokenizer(
+             model_name=self.tokenizer_name
+         )
+
+     async def evaluate_single(self, pair: TextPair) -> float:
+         loop = create_event_loop()
+         return await loop.run_in_executor(None, self._calculate_length, pair.answer)
+
+     def _calculate_length(self, text: str) -> float:
+         tokens = self.tokenizer.encode_string(text)
+         return len(tokens)
graphgen/models/evaluate/mtld_evaluator.py ADDED
@@ -0,0 +1,76 @@
+ from dataclasses import dataclass, field
+ from typing import Set
+
+ from graphgen.models.evaluate.base_evaluator import BaseEvaluator
+ from graphgen.models.text.text_pair import TextPair
+ from graphgen.utils import detect_main_language, NLTKHelper, create_event_loop
+
+
+ nltk_helper = NLTKHelper()
+
+ @dataclass
+ class MTLDEvaluator(BaseEvaluator):
+     """
+     Metric for the lexical diversity of a text.
+     """
+     stopwords_en: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("english")))
+     stopwords_zh: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("chinese")))
+
+     async def evaluate_single(self, pair: TextPair) -> float:
+         loop = create_event_loop()
+         return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)
+
+     def _calculate_mtld_score(self, text: str, threshold=0.72) -> float:
+         """
+         Compute MTLD (average of the forward and backward passes).
+
+         min is 1.0
+         higher is better
+         """
+         if not text or not text.strip():
+             return 0.0
+
+         lang = detect_main_language(text)
+         tokens = nltk_helper.word_tokenize(text, lang)
+
+         stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
+         filtered_tokens = [word for word in tokens if word not in stopwords]
+         filtered_tokens = [word for word in filtered_tokens if word.isalnum()]
+
+         if not filtered_tokens:
+             return 0
+
+         # forward MTLD
+         forward_factors = self._compute_factors(filtered_tokens, threshold)
+
+         # backward MTLD
+         backward_factors = self._compute_factors(filtered_tokens[::-1], threshold)
+
+         # average the two directions
+         return (forward_factors + backward_factors) / 2
+
+     @staticmethod
+     def _compute_factors(tokens: list, threshold: float) -> float:
+         factors = 0
+         current_segment = []
+         unique_words = set()
+
+         for token in tokens:
+             current_segment.append(token)
+             unique_words.add(token)
+             ttr = len(unique_words) / len(current_segment)
+
+             if ttr <= threshold:
+                 factors += 1
+                 current_segment = []
+                 unique_words = set()
+
+         # handle the final, incomplete segment
+         if current_segment:
+             ttr = len(unique_words) / len(current_segment)
+             if ttr <= threshold:
+                 factors += 1
+             else:
+                 factors += (1 - (ttr - threshold) / (1 - threshold))
+
+         return len(tokens) / factors if factors > 0 else len(tokens)
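
Note: MTLD counts a "factor" each time the running type-token ratio of the filtered tokens drops to the 0.72 threshold; the score is tokens per factor, averaged over forward and backward passes, so repetitive text scores low. A usage sketch (sample sentence invented):

    from graphgen.models import MTLDEvaluator, TextPair

    evaluator = MTLDEvaluator()
    pairs = [TextPair(question="Q?", answer="The quick brown fox jumps over the lazy dog.")]
    print(evaluator.get_average_score(pairs))  # higher means more lexical diversity
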
graphgen/models/evaluate/reward_evaluator.py ADDED
@@ -0,0 +1,101 @@
+ from dataclasses import dataclass
+ from tqdm import tqdm
+ from graphgen.models.text.text_pair import TextPair
+
+
+ @dataclass
+ class RewardEvaluator:
+     """
+     Reward Model Evaluator.
+     OpenAssistant/reward-model-deberta-v3-large-v2: scores range over (-inf, inf); higher is better.
+     """
+     reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
+     max_length: int = 2560
+     results: list[float] = None
+
+     def __post_init__(self):
+         import torch
+         self.num_gpus = torch.cuda.device_count()
+
+     @staticmethod
+     def process_chunk(rank, pairs, reward_name, max_length, return_dict):
+         import torch
+         from transformers import AutoModelForSequenceClassification, AutoTokenizer
+         device = f'cuda:{rank}'
+         torch.cuda.set_device(rank)
+
+         rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
+         tokenizer = AutoTokenizer.from_pretrained(reward_name)
+         rank_model.to(device)
+         rank_model.eval()
+
+         results = []
+         with torch.no_grad():
+             for pair in tqdm(pairs):
+                 inputs = tokenizer(
+                     pair.question,
+                     pair.answer,
+                     return_tensors="pt",
+                     max_length=max_length,
+                     truncation=True
+                 )
+                 inputs = {k: v.to(device) for k, v in inputs.items()}
+                 score = rank_model(**inputs).logits[0].item()
+                 results.append(score)
+
+         return_dict[rank] = results
+
+     def evaluate(self, pairs: list[TextPair]) -> list[float]:
+         import torch.multiprocessing as mp
+         chunk_size = len(pairs) // self.num_gpus
+         chunks = []
+         for i in range(self.num_gpus):
+             start = i * chunk_size
+             end = start + chunk_size
+             if i == self.num_gpus - 1:
+                 end = len(pairs)
+             chunks.append(pairs[start:end])
+
+         # multi-process
+         manager = mp.Manager()
+         return_dict = manager.dict()
+         processes = []
+
+         for rank, chunk in enumerate(chunks):
+             p = mp.Process(
+                 target=self.process_chunk,
+                 args=(rank, chunk, self.reward_name, self.max_length, return_dict)
+             )
+             p.start()
+             processes.append(p)
+
+         for p in processes:
+             p.join()
+
+         # merge results
+         results = []
+         for rank in range(len(chunks)):
+             results.extend(return_dict[rank])
+
+         for p in processes:
+             if p.is_alive():
+                 p.terminate()
+                 p.join()
+
+         return results
+
+     def get_average_score(self, pairs: list[TextPair]) -> float:
+         """
+         Get the average score of a batch of texts.
+         """
+         results = self.evaluate(pairs)
+         self.results = results
+         return sum(self.results) / len(pairs)
+
+     def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
+         """
+         Get the min and max score of a batch of texts.
+         """
+         if self.results is None:
+             self.get_average_score(pairs)
+         return min(self.results), max(self.results)
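
Note: unlike the BaseEvaluator subclasses, RewardEvaluator shards pairs across all visible GPUs with torch.multiprocessing, so the spawn start method must be set first (as evaluate.py above does). A usage sketch (requires at least one CUDA device; the model downloads on first use):

    import torch.multiprocessing as mp
    from graphgen.models import RewardEvaluator, TextPair

    if __name__ == '__main__':
        mp.set_start_method('spawn')
        evaluator = RewardEvaluator()  # OpenAssistant/reward-model-deberta-v3-large-v2
        pairs = [TextPair(question="What is MTLD?", answer="A lexical diversity metric.")]
        print(evaluator.get_average_score(pairs))
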
graphgen/models/evaluate/uni_evaluator.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/maszhongming/UniEval/tree/main
2
+
3
+ from dataclasses import dataclass, field
4
+ from tqdm import tqdm
5
+ from graphgen.models.text.text_pair import TextPair
6
+
7
+
8
+ def _add_questions(dimension: str, question: str, answer: str):
9
+ if dimension == "naturalness":
10
        cur_input = 'question: Is this a natural response in the dialogue? </s> response: ' + answer
    elif dimension == "coherence":
        cur_input = 'question: Is this a coherent response given the dialogue history? </s> response: ' \
                    + answer + ' </s> dialogue history: ' + question
    elif dimension == "understandability":
        cur_input = 'question: Is this an understandable response in the dialogue? </s> response: ' + answer
    else:
        raise NotImplementedError(
            'The input format for this dimension is still undefined. Please customize it first.')
    return cur_input


@dataclass
class UniEvaluator:
    model_name: str = "MingZhong/unieval-sum"
    dimensions: list = field(default_factory=lambda: ['naturalness', 'coherence', 'understandability'])
    max_length: int = 2560
    results: dict = None

    def __post_init__(self):
        import torch
        self.num_gpus = torch.cuda.device_count()
        self.results = {}

    @staticmethod
    def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict):
        import torch
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        device = f'cuda:{rank}'
        torch.cuda.set_device(rank)

        rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        rank_model.to(device)
        rank_model.eval()

        softmax = torch.nn.Softmax(dim=1)

        pos_id = tokenizer("Yes")["input_ids"][0]
        neg_id = tokenizer("No")["input_ids"][0]

        results = []
        with torch.no_grad():
            for pair in tqdm(pairs):
                text = _add_questions(dimension, pair.question, pair.answer)

                tgt = "No"

                encoded_src = tokenizer(
                    text,
                    max_length=max_length,
                    truncation=True,
                    padding=True,
                    return_tensors='pt'
                )
                encoded_tgt = tokenizer(
                    tgt,
                    max_length=max_length,
                    truncation=True,
                    padding=True,
                    return_tensors='pt'
                )

                src_tokens = encoded_src['input_ids'].to(device)
                src_mask = encoded_src['attention_mask'].to(device)

                tgt_tokens = encoded_tgt['input_ids'].to(device)[:, 0].unsqueeze(-1)

                output = rank_model(
                    input_ids=src_tokens,
                    attention_mask=src_mask,
                    labels=tgt_tokens,
                    use_cache=False
                )

                logits = output.logits.view(-1, rank_model.config.vocab_size)

                pos_score = softmax(logits)[:, pos_id]  # probability of "Yes"
                neg_score = softmax(logits)[:, neg_id]  # probability of "No"
                score = pos_score / (pos_score + neg_score)

                results.append(score.item())

        return_dict[rank] = results

    def evaluate(self, pairs: list[TextPair]) -> list[dict]:
        import torch.multiprocessing as mp
        final_results = []
        for dimension in self.dimensions:
            chunk_size = len(pairs) // self.num_gpus
            chunks = []
            for i in range(self.num_gpus):
                start = i * chunk_size
                end = start + chunk_size
                if i == self.num_gpus - 1:
                    end = len(pairs)
                chunks.append(pairs[start:end])

            # multi-process
            manager = mp.Manager()
            return_dict = manager.dict()
            processes = []

            for rank, chunk in enumerate(chunks):
                p = mp.Process(
                    target=self.process_chunk,
                    args=(rank, chunk, self.model_name, self.max_length, dimension, return_dict)
                )
                p.start()
                processes.append(p)

            for p in processes:
                p.join()

            # merge the per-GPU results
            results = []
            for rank in range(len(chunks)):
                results.extend(return_dict[rank])

            for p in processes:
                if p.is_alive():
                    p.terminate()
                    p.join()

            final_results.append({
                dimension: results
            })
        return final_results

    def get_average_score(self, pairs: list[TextPair]) -> dict:
        """
        Get the average score of a batch of texts.
        """
        results = self.evaluate(pairs)
        final_results = {}
        for result in results:
            for key, value in result.items():
                final_results[key] = sum(value) / len(value)
                self.results[key] = value
        return final_results

    def get_min_max_score(self, pairs: list[TextPair]) -> dict:
        """
        Get the min and max score of a batch of texts.
        """
        if not self.results:  # empty until evaluate() has run
            self.get_average_score(pairs)
        final_results = {}
        for key, value in self.results.items():
            final_results[key] = min(value), max(value)
        return final_results
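For reference, a minimal usage sketch of the evaluator above. The import paths are assumptions (TextPair and UniEvaluator re-exported from graphgen.models), and running it requires at least one CUDA device plus the transformers package.

# Usage sketch; import paths and CUDA availability are assumptions.
from graphgen.models import TextPair, UniEvaluator

pairs = [
    TextPair(question="What is a knowledge graph?",
             answer="A structured representation of entities and relations."),
]
evaluator = UniEvaluator()                  # defaults to MingZhong/unieval-sum
print(evaluator.get_average_score(pairs))   # e.g. {'naturalness': 0.87, ...}
print(evaluator.get_min_max_score(pairs))   # per-dimension (min, max) tuples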
graphgen/models/llm/__init__.py ADDED
File without changes
graphgen/models/llm/limitter.py ADDED
@@ -0,0 +1,88 @@
import time
from datetime import datetime, timedelta
import asyncio

from graphgen.utils import logger


class RPM:

    def __init__(self, rpm: int = 1000):
        self.rpm = rpm
        self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}

    def get_minute_slot(self):
        current_time = time.time()
        dt_object = datetime.fromtimestamp(current_time)
        total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
        return total_minutes_since_midnight

    async def wait(self, silent=False):
        current = time.time()
        dt_object = datetime.fromtimestamp(current)
        minute_slot = self.get_minute_slot()

        if self.record['rpm_slot'] == minute_slot:
            # check whether the RPM limit has been exceeded
            if self.record['counter'] >= self.rpm:
                # wait until the next minute
                next_minute = dt_object.replace(
                    second=0, microsecond=0) + timedelta(minutes=1)
                _next = next_minute.timestamp()
                sleep_time = abs(_next - current)
                if not silent:
                    logger.info('RPM sleep %s', sleep_time)
                await asyncio.sleep(sleep_time)

                self.record = {
                    'rpm_slot': self.get_minute_slot(),
                    'counter': 0
                }
        else:
            self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}
        self.record['counter'] += 1

        if not silent:
            logger.debug(self.record)


class TPM:

    def __init__(self, tpm: int = 20000):
        self.tpm = tpm
        self.record = {'tpm_slot': self.get_minute_slot(), 'counter': 0}

    def get_minute_slot(self):
        current_time = time.time()
        dt_object = datetime.fromtimestamp(current_time)
        total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
        return total_minutes_since_midnight

    async def wait(self, token_count, silent=False):
        current = time.time()
        dt_object = datetime.fromtimestamp(current)
        minute_slot = self.get_minute_slot()

        # a new minute slot has begun: reset the counter and return
        if self.record['tpm_slot'] != minute_slot:
            self.record = {'tpm_slot': minute_slot, 'counter': token_count}
            return

        # check whether the TPM limit has been exceeded
        self.record['counter'] += token_count
        if self.record['counter'] > self.tpm:
            # wait until the next minute
            next_minute = dt_object.replace(
                second=0, microsecond=0) + timedelta(minutes=1)
            _next = next_minute.timestamp()
            sleep_time = abs(_next - current)
            logger.info('TPM sleep %s', sleep_time)
            await asyncio.sleep(sleep_time)

            self.record = {
                'tpm_slot': self.get_minute_slot(),
                'counter': token_count
            }

        if not silent:
            logger.debug(self.record)
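Both limiters key their counters to the current wall-clock minute rather than a true sliding window, so a burst can straddle a slot boundary. A minimal sketch of gating requests on both limiters, with hypothetical budget values:

# Sketch: gate each request on both limiters before dispatching it.
import asyncio
from graphgen.models.llm.limitter import RPM, TPM

async def main():
    rpm, tpm = RPM(rpm=60), TPM(tpm=10000)            # hypothetical budgets
    for i in range(3):
        await rpm.wait(silent=True)
        await tpm.wait(token_count=800, silent=True)  # estimated tokens for this call
        print(f"request {i} dispatched")

asyncio.run(main())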
graphgen/models/llm/openai_model.py ADDED
@@ -0,0 +1,130 @@
import math
from dataclasses import dataclass, field
from typing import List, Dict, Optional
import openai
from openai import AsyncOpenAI, RateLimitError, APIConnectionError, APITimeoutError
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

from graphgen.models.llm.topk_token_model import TopkTokenModel, Token
from graphgen.models.llm.tokenizer import Tokenizer
from graphgen.models.llm.limitter import RPM, TPM


def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
    token_logprobs = response.choices[0].logprobs.content
    tokens = []
    for token_prob in token_logprobs:
        prob = math.exp(token_prob.logprob)
        candidate_tokens = [
            Token(t.token, math.exp(t.logprob))
            for t in token_prob.top_logprobs
        ]
        token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
        tokens.append(token)
    return tokens


@dataclass
class OpenAIModel(TopkTokenModel):
    model_name: str = "gpt-4o-mini"
    api_key: str = None
    base_url: str = None

    system_prompt: str = ""
    json_mode: bool = False
    seed: int = None

    token_usage: list = field(default_factory=list)
    request_limit: bool = False
    rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
    tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))

    def __post_init__(self):
        assert self.api_key is not None, "Please provide an api key to access the openai api."
        if self.api_key == "":
            self.api_key = "none"
        self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

    def _pre_generate(self, text: str, history: List[str]) -> Dict:
        kwargs = {
            "temperature": self.temperature,
            "top_p": self.topp,
            "max_tokens": self.max_tokens,
        }
        if self.seed:
            kwargs["seed"] = self.seed
        if self.json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": text})

        if history:
            assert len(history) % 2 == 0, "History should have an even number of elements."
            messages = history + messages

        kwargs["messages"] = messages
        return kwargs

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
    )
    async def generate_topk_per_token(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
        kwargs = self._pre_generate(text, history)
        if self.topk_per_token > 0:
            kwargs["logprobs"] = True
            kwargs["top_logprobs"] = self.topk_per_token

        # Limit max_tokens to 1 to avoid long completions
        kwargs["max_tokens"] = 1

        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
            model=self.model_name,
            **kwargs
        )

        tokens = get_top_response_tokens(completion)

        return tokens

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
    )
    async def generate_answer(self, text: str, history: Optional[List[str]] = None, temperature: float = 0) -> str:
        kwargs = self._pre_generate(text, history)
        kwargs["temperature"] = temperature

        prompt_tokens = 0
        for message in kwargs["messages"]:
            prompt_tokens += len(Tokenizer().encode_string(message["content"]))
        estimated_tokens = prompt_tokens + kwargs["max_tokens"]

        if self.request_limit:
            await self.rpm.wait(silent=True)
            await self.tpm.wait(estimated_tokens, silent=True)

        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
            model=self.model_name,
            **kwargs
        )
        if hasattr(completion, "usage"):
            self.token_usage.append({
                "prompt_tokens": completion.usage.prompt_tokens,
                "completion_tokens": completion.usage.completion_tokens,
                "total_tokens": completion.usage.total_tokens,
            })
        return completion.choices[0].message.content

    async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
        raise NotImplementedError
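A minimal sketch of driving this client against an OpenAI-compatible endpoint; the environment variable names are placeholders, and generate_topk_per_token requires an endpoint that returns logprobs:

import os
import asyncio
from graphgen.models.llm.openai_model import OpenAIModel

async def main():
    llm = OpenAIModel(
        model_name="gpt-4o-mini",
        api_key=os.environ["LLM_API_KEY"],        # placeholder variable names
        base_url=os.environ.get("LLM_BASE_URL"),
    )
    print(await llm.generate_answer("Name one graph traversal algorithm."))
    tokens = await llm.generate_topk_per_token("Answer yes or no: is the sky blue?")
    print(tokens[0].top_candidates)               # candidates for the first output token

asyncio.run(main())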
graphgen/models/llm/tokenizer.py ADDED
@@ -0,0 +1,73 @@
from dataclasses import dataclass
from typing import List
import tiktoken

try:
    from transformers import AutoTokenizer
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    AutoTokenizer = None
    TRANSFORMERS_AVAILABLE = False


def get_tokenizer(tokenizer_name: str = "cl100k_base"):
    """
    Get a tokenizer instance by name.

    :param tokenizer_name: tokenizer name, a tiktoken encoding name or a Hugging Face model name
    :return: tokenizer instance
    """
    if tokenizer_name in tiktoken.list_encoding_names():
        return tiktoken.get_encoding(tokenizer_name)
    if TRANSFORMERS_AVAILABLE:
        try:
            return AutoTokenizer.from_pretrained(tokenizer_name)
        except Exception as e:
            raise ValueError(f"Failed to load tokenizer from Hugging Face: {e}") from e
    raise ValueError("Hugging Face Transformers is not available, please install it first.")


@dataclass
class Tokenizer:
    model_name: str = "cl100k_base"

    def __post_init__(self):
        self.tokenizer = get_tokenizer(self.model_name)

    def encode_string(self, text: str) -> List[int]:
        """
        Encode text to tokens

        :param text: text to encode
        :return: tokens
        """
        return self.tokenizer.encode(text)

    def decode_tokens(self, tokens: List[int]) -> str:
        """
        Decode tokens to text

        :param tokens: tokens to decode
        :return: text
        """
        return self.tokenizer.decode(tokens)

    def chunk_by_token_size(
        self, content: str, overlap_token_size=128, max_token_size=1024
    ):
        tokens = self.encode_string(content)
        results = []
        for index, start in enumerate(
            range(0, len(tokens), max_token_size - overlap_token_size)
        ):
            chunk_content = self.decode_tokens(
                tokens[start : start + max_token_size]
            )
            results.append(
                {
                    "tokens": min(max_token_size, len(tokens) - start),
                    "content": chunk_content.strip(),
                    "chunk_order_index": index,
                }
            )
        return results
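The chunk windows advance by max_token_size - overlap_token_size tokens, so consecutive chunks share the overlap. A quick sketch (tiktoken must be installed):

from graphgen.models.llm.tokenizer import Tokenizer

tok = Tokenizer("cl100k_base")
chunks = tok.chunk_by_token_size("one two three four " * 20,
                                 overlap_token_size=2, max_token_size=8)
# Windows start every 8 - 2 = 6 tokens; the last chunk may be shorter.
for c in chunks[:3]:
    print(c["chunk_order_index"], c["tokens"], c["content"][:30])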
graphgen/models/llm/topk_token_model.py ADDED
@@ -0,0 +1,48 @@
import math
from dataclasses import dataclass, field
from typing import List, Union, Optional


@dataclass
class Token:
    text: str
    prob: float
    top_candidates: List = field(default_factory=list)
    ppl: Union[float, None] = field(default=None)

    @property
    def logprob(self) -> float:
        return math.log(self.prob)


@dataclass
class TopkTokenModel:
    do_sample: bool = False
    temperature: float = 0
    max_tokens: int = 4096
    repetition_penalty: float = 1.05
    num_beams: int = 1
    topk: int = 50
    topp: float = 0.95

    topk_per_token: int = 5  # number of topk tokens to generate for each token

    async def generate_topk_per_token(self, text: str) -> List[Token]:
        """
        Generate prob, text and candidates for each token of the model's output.
        This function is used to visualize the inference process.
        """
        raise NotImplementedError

    async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
        """
        Generate prob and text for each token of the input text.
        This function is used to visualize the ppl.
        """
        raise NotImplementedError

    async def generate_answer(self, text: str, history: Optional[List[str]] = None) -> str:
        """
        Generate answer from the model.
        """
        raise NotImplementedError
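Token stores a linear probability and derives the log-probability on demand, the inverse of the math.exp conversion used by the OpenAI client above:

import math
from graphgen.models.llm.topk_token_model import Token

t = Token(text="yes", prob=0.9)
assert abs(t.logprob - math.log(0.9)) < 1e-12  # prob and logprob round-trip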
graphgen/models/search/__init__.py ADDED
File without changes
graphgen/models/search/wiki_search.py ADDED
@@ -0,0 +1,36 @@
from typing import List, Union
from dataclasses import dataclass

import wikipedia
from wikipedia import set_lang
from graphgen.utils import detect_main_language, logger


@dataclass
class WikiSearch:
    @staticmethod
    def set_language(language: str):
        assert language in ["en", "zh"], "Only English and Chinese are supported"
        set_lang(language)

    async def search(self, query: str) -> Union[List[str], None]:
        self.set_language(detect_main_language(query))
        return wikipedia.search(query)

    async def summary(self, query: str) -> Union[str, None]:
        self.set_language(detect_main_language(query))
        try:
            result = wikipedia.summary(query, auto_suggest=False, redirect=False)
        except wikipedia.exceptions.DisambiguationError as e:
            logger.error("DisambiguationError: %s", e)
            result = None
        return result

    async def page(self, query: str) -> Union[str, None]:
        self.set_language(detect_main_language(query))
        try:
            result = wikipedia.page(query, auto_suggest=False, redirect=False).content
        except wikipedia.exceptions.DisambiguationError as e:
            logger.error("DisambiguationError: %s", e)
            result = None
        return result
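A minimal sketch of the search client; it needs network access and the wikipedia package, and summary returns None for disambiguation pages (other wikipedia exceptions, e.g. redirects, are not caught here):

import asyncio
from graphgen.models.search.wiki_search import WikiSearch

async def main():
    ws = WikiSearch()
    titles = await ws.search("Knowledge graph")
    print(titles[:3])
    summary = await ws.summary("Knowledge graph")
    print((summary or "")[:120])

asyncio.run(main())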
graphgen/models/storage/__init__.py ADDED
File without changes
graphgen/models/storage/base_storage.py ADDED
@@ -0,0 +1,94 @@
from dataclasses import dataclass
from typing import Union, Generic, TypeVar
from graphgen.models.embed.embedding import EmbeddingFunc

T = TypeVar("T")


@dataclass
class StorageNameSpace:
    working_dir: str = None
    namespace: str = None

    async def index_done_callback(self):
        """commit the storage operations after indexing"""

    async def query_done_callback(self):
        """commit the storage operations after querying"""


@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    embedding_func: EmbeddingFunc = None

    async def all_keys(self) -> list[str]:
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        raise NotImplementedError

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """return the keys that do not yet exist in the storage"""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        raise NotImplementedError

    async def drop(self):
        raise NotImplementedError


@dataclass
class BaseGraphStorage(StorageNameSpace):
    embedding_func: EmbeddingFunc = None

    async def has_node(self, node_id: str) -> bool:
        raise NotImplementedError

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        raise NotImplementedError

    async def node_degree(self, node_id: str) -> int:
        raise NotImplementedError

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        raise NotImplementedError

    async def get_node(self, node_id: str) -> Union[dict, None]:
        raise NotImplementedError

    async def update_node(self, node_id: str, node_data: dict[str, str]):
        raise NotImplementedError

    async def get_all_nodes(self) -> Union[list[dict], None]:
        raise NotImplementedError

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        raise NotImplementedError

    async def update_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
        raise NotImplementedError

    async def get_all_edges(self) -> Union[list[dict], None]:
        raise NotImplementedError

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        raise NotImplementedError

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        raise NotImplementedError

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        raise NotImplementedError

    async def delete_node(self, node_id: str):
        raise NotImplementedError
graphgen/models/storage/json_storage.py ADDED
@@ -0,0 +1,51 @@
import os

from dataclasses import dataclass
from graphgen.utils import logger, load_json, write_json
from graphgen.models.storage.base_storage import BaseKVStorage


@dataclass
class JsonKVStorage(BaseKVStorage):
    _data: dict[str, str] = None

    def __post_init__(self):
        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
        self._data = load_json(self._file_name) or {}
        logger.info("Load KV %s with %d data", self.namespace, len(self._data))

    @property
    def data(self):
        return self._data

    async def all_keys(self) -> list[str]:
        return list(self._data.keys())

    async def index_done_callback(self):
        write_json(self._data, self._file_name)

    async def get_by_id(self, id):
        return self._data.get(id, None)

    async def get_by_ids(self, ids, fields=None) -> list:
        if fields is None:
            return [self._data.get(id, None) for id in ids]
        return [
            (
                {k: v for k, v in self._data[id].items() if k in fields}
                if self._data.get(id, None)
                else None
            )
            for id in ids
        ]

    async def filter_keys(self, data: list[str]) -> set[str]:
        return {s for s in data if s not in self._data}

    async def upsert(self, data: dict):
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)
        return left_data

    async def drop(self):
        self._data = {}
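Note that upsert only inserts keys that are absent and returns that subset, so existing values are never overwritten. A sketch, assuming the working directory already exists:

import asyncio
from graphgen.models.storage.json_storage import JsonKVStorage

async def main():
    kv = JsonKVStorage(working_dir="cache", namespace="demo")  # cache/demo.json
    print(await kv.upsert({"a": 1, "b": 2}))    # {'a': 1, 'b': 2}
    print(await kv.upsert({"b": 99, "c": 3}))   # {'c': 3}; 'b' keeps its old value
    print(await kv.filter_keys(["a", "z"]))     # {'z'}
    await kv.index_done_callback()              # persist to disk

asyncio.run(main())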
graphgen/models/storage/networkx_storage.py ADDED
@@ -0,0 +1,159 @@
import os
import html
from typing import Any, Union, cast, Optional
from dataclasses import dataclass
import networkx as nx

from graphgen.utils import logger
from .base_storage import BaseGraphStorage


@dataclass
class NetworkXStorage(BaseGraphStorage):
    @staticmethod
    def load_nx_graph(file_name) -> Optional[nx.Graph]:
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None

    @staticmethod
    def write_nx_graph(graph: nx.Graph, file_name):
        logger.info("Writing graph with %d nodes, %d edges", graph.number_of_nodes(), graph.number_of_edges())
        nx.write_graphml(graph, file_name)

    @staticmethod
    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
        """
        from graspologic.utils import largest_connected_component

        graph = graph.copy()
        graph = cast(nx.Graph, largest_connected_component(graph))
        node_mapping = {
            node: html.unescape(node.upper().strip()) for node in graph.nodes()
        }  # type: ignore
        graph = nx.relabel_nodes(graph, node_mapping)
        return NetworkXStorage._stabilize_graph(graph)

    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.
        This is achieved by sorting the nodes and edges.
        """
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()

        sorted_nodes = graph.nodes(data=True)
        sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])

        fixed_graph.add_nodes_from(sorted_nodes)
        edges = list(graph.edges(data=True))

        if not graph.is_directed():

            def _sort_source_target(edge):
                source, target, edge_data = edge
                if source > target:
                    source, target = target, source
                return source, target, edge_data

            edges = [_sort_source_target(edge) for edge in edges]

        def _get_edge_key(source: Any, target: Any) -> str:
            return f"{source} -> {target}"

        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))

        fixed_graph.add_edges_from(edges)
        return fixed_graph

    def __post_init__(self):
        """
        Load the graph file if it exists; otherwise create a new graph.
        """
        self._graphml_xml_file = os.path.join(
            self.working_dir, f"{self.namespace}.graphml"
        )
        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        if preloaded_graph is not None:
            logger.info(
                "Loaded graph from %s with %d nodes, %d edges", self._graphml_xml_file,
                preloaded_graph.number_of_nodes(), preloaded_graph.number_of_edges()
            )
        self._graph = preloaded_graph or nx.Graph()

    async def index_done_callback(self):
        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        return self._graph.has_node(node_id)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        return self._graph.has_edge(source_node_id, target_node_id)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        return self._graph.nodes.get(node_id)

    async def get_all_nodes(self) -> Union[list[dict], None]:
        return self._graph.nodes(data=True)

    async def node_degree(self, node_id: str) -> int:
        return self._graph.degree(node_id)

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        return self._graph.degree(src_id) + self._graph.degree(tgt_id)

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        return self._graph.edges.get((source_node_id, target_node_id))

    async def get_all_edges(self) -> Union[list[dict], None]:
        return self._graph.edges(data=True)

    async def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
        if self._graph.has_node(source_node_id):
            return list(self._graph.edges(source_node_id, data=True))
        return None

    async def get_graph(self) -> nx.Graph:
        return self._graph

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        self._graph.add_node(node_id, **node_data)

    async def update_node(self, node_id: str, node_data: dict[str, str]):
        if self._graph.has_node(node_id):
            self._graph.nodes[node_id].update(node_data)
        else:
            logger.warning("Node %s not found in the graph for update.", node_id)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        self._graph.add_edge(source_node_id, target_node_id, **edge_data)

    async def update_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
        if self._graph.has_edge(source_node_id, target_node_id):
            self._graph.edges[(source_node_id, target_node_id)].update(edge_data)
        else:
            logger.warning("Edge %s -> %s not found in the graph for update.", source_node_id, target_node_id)

    async def delete_node(self, node_id: str):
        """
        Delete a node from the graph based on the specified node_id.

        :param node_id: The node_id to delete
        """
        if self._graph.has_node(node_id):
            self._graph.remove_node(node_id)
            logger.info("Node %s deleted from the graph.", node_id)
        else:
            logger.warning("Node %s not found in the graph for deletion.", node_id)

    async def clear(self):
        """
        Clear the graph by removing all nodes and edges.
        """
        self._graph.clear()
        logger.info("Graph %s cleared.", self.namespace)
graphgen/models/strategy/__init__.py ADDED
File without changes
graphgen/models/strategy/base_strategy.py ADDED
@@ -0,0 +1,5 @@
from dataclasses import dataclass


@dataclass
class BaseStrategy:
    pass
graphgen/models/strategy/travserse_strategy.py ADDED
@@ -0,0 +1,30 @@
from dataclasses import dataclass, fields

from graphgen.models.strategy.base_strategy import BaseStrategy


@dataclass
class TraverseStrategy(BaseStrategy):
    # Form of the generated QA pairs: atomic, multi-hop, or aggregated
    qa_form: str = "atomic"  # "atomic" or "multi_hop" or "aggregated"
    # Expansion budget: only one of max edge count / max token count takes effect
    expand_method: str = "max_tokens"  # "max_width" or "max_tokens"
    # Expand in one direction or in both directions
    bidirectional: bool = True
    # Maximum number of extra edges per direction
    max_extra_edges: int = 5
    # Maximum number of tokens
    max_tokens: int = 256
    # Maximum expansion depth per direction
    max_depth: int = 2
    # Edge-selection strategy within a level (for bidirectional expansion,
    # a level is the set of edges connected on either side)
    edge_sampling: str = "max_loss"  # "max_loss" or "min_loss" or "random"
    # How to handle isolated nodes
    isolated_node_strategy: str = "add"  # "add" or "ignore"
    loss_strategy: str = "only_edge"  # only_edge, both

    def to_yaml(self):
        strategy_dict = {}
        for f in fields(self):
            strategy_dict[f.name] = getattr(self, f.name)
        return {"traverse_strategy": strategy_dict}
graphgen/models/text/__init__.py ADDED
File without changes
graphgen/models/text/chunk.py ADDED
@@ -0,0 +1,7 @@
from dataclasses import dataclass


@dataclass
class Chunk:
    id: str
    content: str
graphgen/models/text/text_pair.py ADDED
@@ -0,0 +1,9 @@
from dataclasses import dataclass


@dataclass
class TextPair:
    """
    A pair of input data.
    """
    question: str
    answer: str
graphgen/operators/__init__.py ADDED
@@ -0,0 +1,16 @@
from .extract_kg import extract_kg
from .quiz import quiz
from .judge import judge_statement, skip_judge_statement
from .search_wikipedia import search_wikipedia
from .traverse_graph import traverse_graph_by_edge, traverse_graph_atomically, traverse_graph_for_multi_hop

__all__ = [
    "extract_kg",
    "quiz",
    "judge_statement",
    "skip_judge_statement",
    "search_wikipedia",
    "traverse_graph_by_edge",
    "traverse_graph_atomically",
    "traverse_graph_for_multi_hop"
]
graphgen/operators/extract_kg.py ADDED
@@ -0,0 +1,132 @@
import re
import asyncio
from typing import List
from collections import defaultdict

import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async
from graphgen.models import Chunk, OpenAIModel, Tokenizer
from graphgen.models.storage.base_storage import BaseGraphStorage
from graphgen.templates import KG_EXTRACTION_PROMPT
from graphgen.utils import (logger, pack_history_conversations, split_string_by_multi_markers,
                            handle_single_entity_extraction, handle_single_relationship_extraction,
                            detect_if_chinese)
from graphgen.operators.merge_kg import merge_nodes, merge_edges


# pylint: disable=too-many-statements
async def extract_kg(
        llm_client: OpenAIModel,
        kg_instance: BaseGraphStorage,
        tokenizer_instance: Tokenizer,
        chunks: List[Chunk],
        progress_bar: gr.Progress = None,
        max_concurrent: int = 1000
):
    """
    :param llm_client: Synthesizer LLM model to extract entities and relationships
    :param kg_instance: graph storage instance
    :param tokenizer_instance: tokenizer instance
    :param chunks: text chunks to extract from
    :param progress_bar: Gradio progress bar to show the progress of the extraction
    :param max_concurrent: max concurrent tasks
    :return:
    """

    semaphore = asyncio.Semaphore(max_concurrent)

    async def _process_single_content(chunk: Chunk, max_loop: int = 3):
        async with semaphore:
            chunk_id = chunk.id
            content = chunk.content
            if detect_if_chinese(content):
                language = "Chinese"
            else:
                language = "English"
            KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language

            hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
                **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
            )

            final_result = await llm_client.generate_answer(hint_prompt)
            logger.info('First result: %s', final_result)

            history = pack_history_conversations(hint_prompt, final_result)
            for loop_index in range(max_loop):
                if_loop_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"],
                    history=history
                )
                if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
                if if_loop_result != "yes":
                    break

                glean_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["CONTINUE"],
                    history=history
                )
                logger.info('Loop %s glean: %s', loop_index, glean_result)

                history += pack_history_conversations(KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result)
                final_result += glean_result
                if loop_index == max_loop - 1:
                    break

            records = split_string_by_multi_markers(
                final_result,
                [
                    KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
                    KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"]],
            )

            nodes = defaultdict(list)
            edges = defaultdict(list)

            for record in records:
                record = re.search(r"\((.*)\)", record)
                if record is None:
                    continue
                record = record.group(1)  # extract the content inside the parentheses
                record_attributes = split_string_by_multi_markers(
                    record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
                )

                entity = await handle_single_entity_extraction(record_attributes, chunk_id)
                if entity is not None:
                    nodes[entity["entity_name"]].append(entity)
                    continue
                relation = await handle_single_relationship_extraction(record_attributes, chunk_id)
                if relation is not None:
                    edges[(relation["src_id"], relation["tgt_id"])].append(relation)
            return dict(nodes), dict(edges)

    results = []
    chunk_number = len(chunks)
    async for result in tqdm_async(
        asyncio.as_completed([_process_single_content(c) for c in chunks]),
        total=len(chunks),
        desc="[3/4]Extracting entities and relationships from chunks",
        unit="chunk",
    ):
        try:
            if progress_bar is not None:
                progress_bar(len(results) / chunk_number, desc="[3/4]Extracting entities and relationships from chunks")
            results.append(await result)
            if progress_bar is not None and len(results) == chunk_number:
                progress_bar(1, desc="[3/4]Extracting entities and relationships from chunks")
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error occurred while extracting entities and relationships from chunks: %s", e)

    nodes = defaultdict(list)
    edges = defaultdict(list)
    for n, e in results:
        for k, v in n.items():
            nodes[k].extend(v)
        for k, v in e.items():
            edges[tuple(sorted(k))].extend(v)

    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)

    return kg_instance
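The record-parsing step above pulls the parenthesized tuple out of each record and splits it on the tuple delimiter. A sketch with hypothetical delimiter values (the real ones live in KG_EXTRACTION_PROMPT["FORMAT"]):

import re

record = '("entity"<|>"CPU"<|>"HARDWARE"<|>"central processing unit")##'
match = re.search(r"\((.*)\)", record)     # content inside the parentheses
attributes = match.group(1).split("<|>")   # "<|>" stands in for the tuple_delimiter
print(attributes)
# ['"entity"', '"CPU"', '"HARDWARE"', '"central processing unit"']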
graphgen/operators/judge.py ADDED
@@ -0,0 +1,188 @@
import math
import asyncio
from tqdm.asyncio import tqdm as tqdm_async
from graphgen.models import NetworkXStorage, OpenAIModel, JsonKVStorage
from graphgen.utils import logger, yes_no_loss_entropy
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT


async def judge_statement(  # pylint: disable=too-many-statements
        trainee_llm_client: OpenAIModel,
        graph_storage: NetworkXStorage,
        rephrase_storage: JsonKVStorage,
        re_judge: bool = False,
        max_concurrent: int = 1000) -> NetworkXStorage:
    """
    Get all edges and nodes and judge them

    :param trainee_llm_client: judge the statements to get comprehension loss
    :param graph_storage: graph storage instance
    :param rephrase_storage: rephrase storage instance
    :param re_judge: re-judge the relations
    :param max_concurrent: max concurrent
    :return:
    """

    semaphore = asyncio.Semaphore(max_concurrent)

    async def _judge_single_relation(
        edge: tuple,
    ):
        async with semaphore:
            source_id = edge[0]
            target_id = edge[1]
            edge_data = edge[2]

            if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
                logger.info("Edge %s -> %s already judged, loss: %s, skip", source_id, target_id, edge_data["loss"])
                return source_id, target_id, edge_data

            description = edge_data["description"]

            try:
                descriptions = await rephrase_storage.get_by_id(description)
                assert descriptions is not None

                judgements = []
                gts = [gt for _, gt in descriptions]
                for description, gt in descriptions:
                    judgement = await trainee_llm_client.generate_topk_per_token(
                        STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=description)
                    )
                    judgements.append(judgement[0].top_candidates)

                loss = yes_no_loss_entropy(judgements, gts)

                logger.info("Edge %s -> %s description: %s loss: %s", source_id, target_id, description, loss)

                edge_data["loss"] = loss
            except Exception as e:  # pylint: disable=broad-except
                logger.error("Error in judging relation %s -> %s: %s", source_id, target_id, e)
                logger.info("Use default loss 0.1")
                edge_data["loss"] = -math.log(0.1)

            await graph_storage.update_edge(source_id, target_id, edge_data)
            return source_id, target_id, edge_data

    edges = await graph_storage.get_all_edges()

    results = []
    for result in tqdm_async(
        asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
        total=len(edges),
        desc="Judging relations"
    ):
        results.append(await result)

    async def _judge_single_entity(
        node: tuple,
    ):
        async with semaphore:
            node_id = node[0]
            node_data = node[1]

            if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
                logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
                return node_id, node_data

            description = node_data["description"]

            try:
                descriptions = await rephrase_storage.get_by_id(description)
                assert descriptions is not None

                judgements = []
                gts = [gt for _, gt in descriptions]
                for description, gt in descriptions:
                    judgement = await trainee_llm_client.generate_topk_per_token(
                        STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=description)
                    )
                    judgements.append(judgement[0].top_candidates)

                loss = yes_no_loss_entropy(judgements, gts)

                logger.info("Node %s description: %s loss: %s", node_id, description, loss)

                node_data["loss"] = loss
            except Exception as e:  # pylint: disable=broad-except
                logger.error("Error in judging entity %s: %s", node_id, e)
                logger.info("Use default loss 0.1")
                node_data["loss"] = -math.log(0.1)

            await graph_storage.update_node(node_id, node_data)
            return node_id, node_data

    nodes = await graph_storage.get_all_nodes()

    results = []
    for result in tqdm_async(
        asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
        total=len(nodes),
        desc="Judging entities"
    ):
        results.append(await result)

    return graph_storage


async def skip_judge_statement(
        graph_storage: NetworkXStorage,
        max_concurrent: int = 1000
):
    """
    Skip the judgement of the statement
    :param graph_storage: graph storage instance
    :param max_concurrent: max concurrent
    :return:
    """
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _skip_single_relation(
        edge: tuple,
    ):
        async with semaphore:
            source_id = edge[0]
            target_id = edge[1]
            edge_data = edge[2]

            if "loss" in edge_data and edge_data["loss"] is not None:
                logger.info("Edge %s -> %s already judged, loss: %s, skip", source_id, target_id, edge_data["loss"])
                return source_id, target_id, edge_data

            edge_data["loss"] = -math.log(0.1)
            await graph_storage.update_edge(source_id, target_id, edge_data)
            return source_id, target_id, edge_data

    edges = await graph_storage.get_all_edges()
    results = []
    for result in tqdm_async(
        asyncio.as_completed([_skip_single_relation(edge) for edge in edges]),
        total=len(edges),
        desc="Skipping judgement of relations"
    ):
        results.append(await result)

    async def _skip_single_entity(
        node: tuple,
    ):
        async with semaphore:
            node_id = node[0]
            node_data = node[1]

            if "loss" in node_data and node_data["loss"] is not None:
                logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
                return node_id, node_data

            node_data["loss"] = -math.log(0.1)
            await graph_storage.update_node(node_id, node_data)
            return node_id, node_data

    nodes = await graph_storage.get_all_nodes()
    results = []
    for result in tqdm_async(
        asyncio.as_completed([_skip_single_entity(node) for node in nodes]),
        total=len(nodes),
        desc="Skipping judgement of entities"
    ):
        results.append(await result)

    return graph_storage
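yes_no_loss_entropy is imported from graphgen.utils and not shown in this commit view; one plausible reading of the idea, hedged as a hypothetical sketch, is the average negative log-probability the trainee assigns to the correct yes/no answer:

# Hypothetical sketch of the comprehension-loss idea; the real
# yes_no_loss_entropy in graphgen.utils may differ in detail.
import math

def yes_no_loss(p_yes_per_statement, ground_truths):
    losses = []
    for p_yes, gt in zip(p_yes_per_statement, ground_truths):
        p_correct = p_yes if gt == "yes" else 1.0 - p_yes
        losses.append(-math.log(max(p_correct, 1e-9)))
    return sum(losses) / len(losses)

print(yes_no_loss([0.9, 0.2], ["yes", "no"]))  # both judged correctly -> low loss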
graphgen/operators/merge_kg.py ADDED
@@ -0,0 +1,215 @@
from collections import Counter
import asyncio
from tqdm.asyncio import tqdm as tqdm_async

from graphgen.utils.format import split_string_by_multi_markers
from graphgen.utils import logger, detect_main_language
from graphgen.models import TopkTokenModel, Tokenizer
from graphgen.models.storage.base_storage import BaseGraphStorage
from graphgen.templates import KG_SUMMARIZATION_PROMPT, KG_EXTRACTION_PROMPT


async def _handle_kg_summary(
        entity_or_relation_name: str,
        description: str,
        llm_client: TopkTokenModel,
        tokenizer_instance: Tokenizer,
        max_summary_tokens: int = 200
) -> str:
    """
    Summarize the description of an entity or relation

    :param entity_or_relation_name: name of the entity or relation
    :param description: description to summarize
    :param llm_client: LLM client used for summarization
    :param tokenizer_instance: tokenizer instance
    :param max_summary_tokens: token budget before summarization kicks in
    :return: new description
    """
    language = detect_main_language(description)
    if language == "en":
        language = "English"
    else:
        language = "Chinese"
    KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language

    tokens = tokenizer_instance.encode_string(description)
    if len(tokens) < max_summary_tokens:
        return description

    use_description = tokenizer_instance.decode_tokens(tokens[:max_summary_tokens])
    prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
        entity_name=entity_or_relation_name,
        description_list=use_description.split('<SEP>'),
        **KG_SUMMARIZATION_PROMPT["FORMAT"]
    )
    new_description = await llm_client.generate_answer(prompt)
    logger.info("Entity or relation %s summary: %s", entity_or_relation_name, new_description)
    return new_description


async def merge_nodes(
        nodes_data: dict,
        kg_instance: BaseGraphStorage,
        llm_client: TopkTokenModel,
        tokenizer_instance: Tokenizer,
        max_concurrent: int = 1000
):
    """
    Merge nodes

    :param nodes_data: nodes to merge
    :param kg_instance: graph storage instance
    :param llm_client: LLM client
    :param tokenizer_instance: tokenizer instance
    :param max_concurrent: max concurrent tasks
    :return
    """

    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_single_node(entity_name: str, node_data: list[dict]):
        async with semaphore:
            entity_types = []
            source_ids = []
            descriptions = []

            node = await kg_instance.get_node(entity_name)
            if node is not None:
                entity_types.append(node["entity_type"])
                source_ids.extend(
                    split_string_by_multi_markers(node["source_id"], ['<SEP>'])
                )
                descriptions.append(node["description"])

            # Count entity_type occurrences across the new and existing node data,
            # then keep the most frequent entity_type
            entity_type = sorted(
                Counter(
                    [dp["entity_type"] for dp in node_data] + entity_types
                ).items(),
                key=lambda x: x[1],
                reverse=True,
            )[0][0]

            description = '<SEP>'.join(
                sorted(set([dp["description"] for dp in node_data] + descriptions))
            )
            description = await _handle_kg_summary(
                entity_name, description, llm_client, tokenizer_instance
            )

            source_id = '<SEP>'.join(
                set([dp["source_id"] for dp in node_data] + source_ids)
            )

            node_data = {
                "entity_type": entity_type,
                "description": description,
                "source_id": source_id
            }
            await kg_instance.upsert_node(
                entity_name,
                node_data=node_data
            )
            node_data["entity_name"] = entity_name
            return node_data

    logger.info("Inserting entities into storage...")
    entities_data = []
    for result in tqdm_async(
        asyncio.as_completed(
            [process_single_node(k, v) for k, v in nodes_data.items()]
        ),
        total=len(nodes_data),
        desc="Inserting entities into storage",
        unit="entity",
    ):
        try:
            entities_data.append(await result)
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error occurred while inserting entities into storage: %s", e)


async def merge_edges(
        edges_data: dict,
        kg_instance: BaseGraphStorage,
        llm_client: TopkTokenModel,
        tokenizer_instance: Tokenizer,
        max_concurrent: int = 1000
):
    """
    Merge edges

    :param edges_data: edges to merge
    :param kg_instance: graph storage instance
    :param llm_client: LLM client
    :param tokenizer_instance: tokenizer instance
    :param max_concurrent: max concurrent tasks
    :return
    """

    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_single_edge(src_id: str, tgt_id: str, edge_data: list[dict]):
        async with semaphore:
            source_ids = []
            descriptions = []

            edge = await kg_instance.get_edge(src_id, tgt_id)
            if edge is not None:
                source_ids.extend(
                    split_string_by_multi_markers(edge["source_id"], ['<SEP>'])
                )
                descriptions.append(edge["description"])

            description = '<SEP>'.join(
                sorted(set([dp["description"] for dp in edge_data] + descriptions))
            )
            source_id = '<SEP>'.join(
                set([dp["source_id"] for dp in edge_data] + source_ids)
            )

            for insert_id in [src_id, tgt_id]:
                if not await kg_instance.has_node(insert_id):
                    await kg_instance.upsert_node(
                        insert_id,
                        node_data={
                            "source_id": source_id,
                            "description": description,
                            "entity_type": "UNKNOWN"
                        }
                    )

            description = await _handle_kg_summary(
                f"({src_id}, {tgt_id})", description, llm_client, tokenizer_instance
            )

            await kg_instance.upsert_edge(
                src_id,
                tgt_id,
                edge_data={
                    "source_id": source_id,
                    "description": description
                }
            )

            edge_data = {
                "src_id": src_id,
                "tgt_id": tgt_id,
                "description": description
            }
            return edge_data

    logger.info("Inserting relationships into storage...")
    relationships_data = []
    for result in tqdm_async(
        asyncio.as_completed(
            [process_single_edge(src_id, tgt_id, v) for (src_id, tgt_id), v in edges_data.items()]
        ),
        total=len(edges_data),
        desc="Inserting relationships into storage",
        unit="relationship",
    ):
        try:
            relationships_data.append(await result)
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error occurred while inserting relationships into storage: %s", e)
graphgen/operators/quiz.py ADDED
@@ -0,0 +1,109 @@
import asyncio
from collections import defaultdict

from tqdm.asyncio import tqdm as tqdm_async
from graphgen.models import JsonKVStorage, OpenAIModel, NetworkXStorage
from graphgen.utils import logger, detect_main_language
from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT


async def quiz(
        synth_llm_client: OpenAIModel,
        graph_storage: NetworkXStorage,
        rephrase_storage: JsonKVStorage,
        max_samples: int = 1,
        max_concurrent: int = 1000) -> JsonKVStorage:
    """
    Get all edges and quiz them

    :param synth_llm_client: generate statements
    :param graph_storage: graph storage instance
    :param rephrase_storage: rephrase storage instance
    :param max_samples: max samples for each edge
    :param max_concurrent: max concurrent
    :return:
    """

    semaphore = asyncio.Semaphore(max_concurrent)

    async def _process_single_quiz(
        des: str,
        prompt: str,
        gt: str
    ):
        async with semaphore:
            try:
                # If the description is already in rephrase_storage, skip it
                descriptions = await rephrase_storage.get_by_id(des)
                if descriptions:
                    return None

                new_description = await synth_llm_client.generate_answer(
                    prompt,
                    temperature=1
                )
                return {des: [(new_description, gt)]}

            except Exception as e:  # pylint: disable=broad-except
                logger.error("Error when quizzing description %s: %s", des, e)
                return None

    edges = await graph_storage.get_all_edges()
    nodes = await graph_storage.get_all_nodes()

    results = defaultdict(list)
    tasks = []
    for edge in edges:
        edge_data = edge[2]

        description = edge_data["description"]
        language = "English" if detect_main_language(description) == "en" else "Chinese"

        results[description] = [(description, 'yes')]

        for i in range(max_samples):
            if i > 0:
                tasks.append(
                    _process_single_quiz(description,
                                         DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(
                                             input_sentence=description), 'yes')
                )
            tasks.append(_process_single_quiz(description,
                                              DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(
                                                  input_sentence=description), 'no'))

    for node in nodes:
        node_data = node[1]
        description = node_data["description"]
        language = "English" if detect_main_language(description) == "en" else "Chinese"

        results[description] = [(description, 'yes')]

        for i in range(max_samples):
            if i > 0:
                tasks.append(
                    _process_single_quiz(description,
                                         DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(
                                             input_sentence=description), 'yes')
                )
            tasks.append(_process_single_quiz(description,
                                              DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(
                                                  input_sentence=description), 'no'))

    for result in tqdm_async(
        asyncio.as_completed(tasks),
        total=len(tasks),
        desc="Quizzing descriptions"
    ):
        new_result = await result
        if new_result:
            for key, value in new_result.items():
                results[key].extend(value)

    for key, value in results.items():
        results[key] = list(set(value))
        await rephrase_storage.upsert({key: results[key]})

    return rephrase_storage
graphgen/operators/resolute_coreference.py ADDED
@@ -0,0 +1,33 @@
from typing import List
from graphgen.models import Chunk
from graphgen.models import OpenAIModel
from graphgen.templates import COREFERENCE_RESOLUTION_TEMPLATE
from graphgen.utils import detect_main_language


async def resolute_coreference(
        llm_client: OpenAIModel,
        chunks: List[Chunk]) -> List[Chunk]:
    """
    Resolve coreferences across chunks, using the first chunk as the reference

    :param llm_client: LLM model
    :param chunks: List of chunks
    :return: List of chunks
    """

    if len(chunks) == 0:
        return chunks

    results = [chunks[0]]

    for _, chunk in enumerate(chunks[1:]):
        language = detect_main_language(chunk.content)
        result = await llm_client.generate_answer(
            COREFERENCE_RESOLUTION_TEMPLATE[language].format(
                reference=results[0].content,
                input_sentence=chunk.content
            )
        )
        results.append(Chunk(id=chunk.id, content=result))

    return results
graphgen/operators/search_wikipedia.py ADDED
@@ -0,0 +1,71 @@
import asyncio
from graphgen.models import WikiSearch, OpenAIModel
from graphgen.models.storage.base_storage import BaseGraphStorage
from graphgen.templates import SEARCH_JUDGEMENT_PROMPT
from graphgen.utils import logger


async def _process_single_entity(entity_name: str,
                                 description: str,
                                 llm_client: OpenAIModel,
                                 wiki_search_client: WikiSearch) -> tuple[str, None] | tuple[str, str]:
    """
    Process a single entity
    """
    search_results = await wiki_search_client.search(entity_name)
    if not search_results:
        return entity_name, None
    examples = "\n".join(SEARCH_JUDGEMENT_PROMPT["EXAMPLES"])
    search_results.append("None of the above")

    search_results_str = "\n".join([f"{i + 1}. {sr}" for i, sr in enumerate(search_results)])
    prompt = SEARCH_JUDGEMENT_PROMPT["TEMPLATE"].format(
        examples=examples,
        entity_name=entity_name,
        description=description,
        search_results=search_results_str,
    )
    response = await llm_client.generate_answer(prompt)

    try:
        response = response.strip()
        response = int(response)
        if response < 1 or response >= len(search_results):
            response = None
        else:
            response = await wiki_search_client.summary(search_results[response - 1])
    except ValueError:
        response = None

    logger.info("Entity %s search result: %s response: %s", entity_name, str(search_results), response)

    return entity_name, response


async def search_wikipedia(llm_client: OpenAIModel,
                           wiki_search_client: WikiSearch,
                           knowledge_graph_instance: BaseGraphStorage) -> dict:
    """
    Search wikipedia for entities

    :param llm_client: LLM model
    :param wiki_search_client: wiki search client
    :param knowledge_graph_instance: knowledge graph instance
    :return: nodes with search results
    """

    nodes = await knowledge_graph_instance.get_all_nodes()
    nodes = list(nodes)
    wiki_data = {}

    tasks = [
        _process_single_entity(node[0].strip('"'), node[1]["description"], llm_client, wiki_search_client)
        for node in nodes
    ]

    for task in asyncio.as_completed(tasks):
        result = await task
        wiki_data[result[0]] = result[1]

    return wiki_data
graphgen/operators/split_graph.py ADDED
@@ -0,0 +1,333 @@
+ import random
+ from collections import defaultdict
+
+ from tqdm.asyncio import tqdm as tqdm_async
+
+ from graphgen.models import NetworkXStorage, TraverseStrategy
+ from graphgen.utils import logger
+
+
+ async def _get_node_info(
+     node_id: str,
+     graph_storage: NetworkXStorage,
+ ) -> dict:
+     """
+     Get node info.
+
+     :param node_id: node id
+     :param graph_storage: graph storage instance
+     :return: node info
+     """
+     node_data = await graph_storage.get_node(node_id)
+     return {
+         "node_id": node_id,
+         **node_data
+     }
+
+
+ def _get_level_n_edges_by_max_width(
+     edge_adj_list: dict,
+     node_dict: dict,
+     edges: list,
+     nodes: list,
+     src_edge: tuple,
+     max_depth: int,
+     bidirectional: bool,
+     max_extra_edges: int,
+     edge_sampling: str,
+     loss_strategy: str = "only_edge"
+ ) -> list:
+     """
+     Get level-n edges for an edge, where n is bounded by max_depth in traverse_strategy.
+
+     :param edge_adj_list: mapping from node id to the indices of its incident edges
+     :param node_dict: mapping from node name to its index in nodes
+     :param edges: all edges
+     :param nodes: all nodes
+     :param src_edge: the edge to expand from
+     :param max_depth: maximum expansion depth
+     :param bidirectional: whether to expand from both endpoints
+     :param max_extra_edges: maximum number of extra edges to collect
+     :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
+     :param loss_strategy: loss strategy (only_edge, both)
+     :return: level-n edges
+     """
+     src_id, tgt_id, _ = src_edge
+
+     level_n_edges = []
+
+     start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
+
+     while max_depth > 0 and max_extra_edges > 0:
+         max_depth -= 1
+
+         candidate_edges = [
+             edges[edge_id]
+             for node in start_nodes
+             for edge_id in edge_adj_list[node]
+             if not edges[edge_id][2].get("visited", False)
+         ]
+
+         if not candidate_edges:
+             break
+
+         if len(candidate_edges) >= max_extra_edges:
+             if loss_strategy == "both":
+                 er_tuples = [
+                     ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
+                     for edge in candidate_edges
+                 ]
+                 candidate_edges = _sort_tuples(er_tuples, edge_sampling)[:max_extra_edges]
+             elif loss_strategy == "only_edge":
+                 candidate_edges = _sort_edges(candidate_edges, edge_sampling)[:max_extra_edges]
+             else:
+                 raise ValueError(f"Invalid loss strategy: {loss_strategy}")
+
+             for edge in candidate_edges:
+                 level_n_edges.append(edge)
+                 edge[2]["visited"] = True
+             break
+
+         max_extra_edges -= len(candidate_edges)
+         new_start_nodes = set()
+
+         for edge in candidate_edges:
+             level_n_edges.append(edge)
+             edge[2]["visited"] = True
+
+             if edge[0] not in start_nodes:
+                 new_start_nodes.add(edge[0])
+             if edge[1] not in start_nodes:
+                 new_start_nodes.add(edge[1])
+
+         start_nodes = new_start_nodes
+
+     return level_n_edges
+
+
+ def _get_level_n_edges_by_max_tokens(
+     edge_adj_list: dict,
+     node_dict: dict,
+     edges: list,
+     nodes: list,
+     src_edge: tuple,
+     max_depth: int,
+     bidirectional: bool,
+     max_tokens: int,
+     edge_sampling: str,
+     loss_strategy: str = "only_edge"
+ ) -> list:
+     """
+     Get level-n edges for an edge, where n is bounded by max_depth in traverse_strategy.
+
+     :param edge_adj_list: mapping from node id to the indices of its incident edges
+     :param node_dict: mapping from node name to its index in nodes
+     :param edges: all edges
+     :param nodes: all nodes
+     :param src_edge: the edge to expand from
+     :param max_depth: maximum expansion depth
+     :param bidirectional: whether to expand from both endpoints
+     :param max_tokens: token budget for the batch
+     :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
+     :param loss_strategy: loss strategy (only_edge, both)
+     :return: level-n edges
+     """
+     src_id, tgt_id, src_edge_data = src_edge
+
+     max_tokens -= (src_edge_data["length"] + nodes[node_dict[src_id]][1]["length"]
+                    + nodes[node_dict[tgt_id]][1]["length"])
+
+     level_n_edges = []
+
+     start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
+     temp_nodes = {src_id, tgt_id}
+
+     while max_depth > 0 and max_tokens > 0:
+         max_depth -= 1
+
+         candidate_edges = [
+             edges[edge_id]
+             for node in start_nodes
+             for edge_id in edge_adj_list[node]
+             if not edges[edge_id][2].get("visited", False)
+         ]
+
+         if not candidate_edges:
+             break
+
+         if loss_strategy == "both":
+             er_tuples = [
+                 ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
+                 for edge in candidate_edges
+             ]
+             candidate_edges = _sort_tuples(er_tuples, edge_sampling)
+         elif loss_strategy == "only_edge":
+             candidate_edges = _sort_edges(candidate_edges, edge_sampling)
+         else:
+             raise ValueError(f"Invalid loss strategy: {loss_strategy}")
+
+         for edge in candidate_edges:
+             max_tokens -= edge[2]["length"]
+             if edge[0] not in temp_nodes:
+                 max_tokens -= nodes[node_dict[edge[0]]][1]["length"]
+             if edge[1] not in temp_nodes:
+                 max_tokens -= nodes[node_dict[edge[1]]][1]["length"]
+
+             if max_tokens < 0:
+                 return level_n_edges
+
+             level_n_edges.append(edge)
+             edge[2]["visited"] = True
+             temp_nodes.add(edge[0])
+             temp_nodes.add(edge[1])
+
+         new_start_nodes = set()
+         for edge in candidate_edges:
+             if edge[0] not in start_nodes:
+                 new_start_nodes.add(edge[0])
+             if edge[1] not in start_nodes:
+                 new_start_nodes.add(edge[1])
+
+         start_nodes = new_start_nodes
+
+     return level_n_edges
+
+
+ def _sort_tuples(er_tuples: list, edge_sampling: str) -> list:
+     """
+     Sort edges with the given edge sampling strategy.
+
+     :param er_tuples: list of (nodes: list, edge: tuple) pairs
+     :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
+     :return: sorted edges
+     """
+     if edge_sampling == "random":
+         er_tuples = random.sample(er_tuples, len(er_tuples))
+     elif edge_sampling == "min_loss":
+         er_tuples = sorted(er_tuples, key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"])
+     elif edge_sampling == "max_loss":
+         er_tuples = sorted(er_tuples, key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
+                            reverse=True)
+     else:
+         raise ValueError(f"Invalid edge sampling: {edge_sampling}")
+     edges = [edge for _, edge in er_tuples]
+     return edges
+
+
+ def _sort_edges(edges: list, edge_sampling: str) -> list:
+     """
+     Sort edges with the given edge sampling strategy.
+
+     :param edges: total edges
+     :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
+     :return: sorted edges
+     """
+     if edge_sampling == "random":
+         random.shuffle(edges)
+     elif edge_sampling == "min_loss":
+         edges = sorted(edges, key=lambda x: x[2]["loss"])
+     elif edge_sampling == "max_loss":
+         edges = sorted(edges, key=lambda x: x[2]["loss"], reverse=True)
+     else:
+         raise ValueError(f"Invalid edge sampling: {edge_sampling}")
+     return edges
+
+
+ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
+     nodes: list,
+     edges: list,
+     graph_storage: NetworkXStorage,
+     traverse_strategy: TraverseStrategy
+ ):
+     expand_method = traverse_strategy.expand_method
+     if expand_method == "max_width":
+         logger.info("Using max width strategy")
+     elif expand_method == "max_tokens":
+         logger.info("Using max tokens strategy")
+     else:
+         raise ValueError(f"Invalid expand method: {expand_method}")
+
+     max_depth = traverse_strategy.max_depth
+     edge_sampling = traverse_strategy.edge_sampling
+
+     # Build the edge adjacency list
+     edge_adj_list = defaultdict(list)
+     node_dict = {}
+     processing_batches = []
+
+     node_cache = {}
+
+     async def get_cached_node_info(node_id: str) -> dict:
+         if node_id not in node_cache:
+             node_cache[node_id] = await _get_node_info(node_id, graph_storage)
+         return node_cache[node_id]
+
+     for i, (node_name, _) in enumerate(nodes):
+         node_dict[node_name] = i
+
+     if traverse_strategy.loss_strategy == "both":
+         er_tuples = [
+             ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
+             for edge in edges
+         ]
+         edges = _sort_tuples(er_tuples, edge_sampling)
+     elif traverse_strategy.loss_strategy == "only_edge":
+         edges = _sort_edges(edges, edge_sampling)
+     else:
+         raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")
+
+     for i, (src, tgt, _) in enumerate(edges):
+         edge_adj_list[src].append(i)
+         edge_adj_list[tgt].append(i)
+
+     for edge in tqdm_async(edges, desc="Preparing batches"):
+         if edge[2].get("visited", False):
+             continue
+
+         edge[2]["visited"] = True
+
+         _process_nodes = []
+         _process_edges = []
+
+         src_id = edge[0]
+         tgt_id = edge[1]
+
+         _process_nodes.extend([await get_cached_node_info(src_id),
+                                await get_cached_node_info(tgt_id)])
+         _process_edges.append(edge)
+
+         if expand_method == "max_width":
+             level_n_edges = _get_level_n_edges_by_max_width(
+                 edge_adj_list, node_dict, edges, nodes, edge, max_depth,
+                 traverse_strategy.bidirectional, traverse_strategy.max_extra_edges,
+                 edge_sampling, traverse_strategy.loss_strategy
+             )
+         else:
+             level_n_edges = _get_level_n_edges_by_max_tokens(
+                 edge_adj_list, node_dict, edges, nodes, edge, max_depth,
+                 traverse_strategy.bidirectional, traverse_strategy.max_tokens,
+                 edge_sampling, traverse_strategy.loss_strategy
+             )
+
+         for _edge in level_n_edges:
+             _process_nodes.append(await get_cached_node_info(_edge[0]))
+             _process_nodes.append(await get_cached_node_info(_edge[1]))
+             _process_edges.append(_edge)
+
+         # Deduplicate nodes and edges
+         _process_nodes = list({node['node_id']: node for node in _process_nodes}.values())
+         _process_edges = list({(edge[0], edge[1]): edge for edge in _process_edges}.values())
+
+         processing_batches.append((_process_nodes, _process_edges))
+
+     logger.info("Processing batches: %d", len(processing_batches))
+
+     # Isolated nodes
+     isolated_node_strategy = traverse_strategy.isolated_node_strategy
+     if isolated_node_strategy == "add":
+         processing_batches = await _add_isolated_nodes(nodes, processing_batches, graph_storage)
+         logger.info("Processing batches after adding isolated nodes: %d", len(processing_batches))
+
+     return processing_batches
+
+
+ async def _add_isolated_nodes(
+     nodes: list,
+     processing_batches: list,
+     graph_storage: NetworkXStorage,
+ ) -> list:
+     visited_nodes = set()
+     for _process_nodes, _ in processing_batches:
+         for node in _process_nodes:
+             visited_nodes.add(node["node_id"])
+     for node in nodes:
+         if node[0] not in visited_nodes:
+             _process_nodes = [await _get_node_info(node[0], graph_storage)]
+             processing_batches.append((_process_nodes, []))
+     return processing_batches
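For intuition, a minimal sketch of the sampling semantics behind _sort_edges above, on toy edges of the form (src, tgt, data); the loss values are invented for illustration:

    edges = [
        ("a", "b", {"loss": 0.9}),
        ("b", "c", {"loss": 0.1}),
        ("c", "d", {"loss": 0.5}),
    ]

    # "min_loss" puts the edges with the smallest stored loss first ...
    assert [e[2]["loss"] for e in sorted(edges, key=lambda x: x[2]["loss"])] == [0.1, 0.5, 0.9]
    # ... while "max_loss" reverses that order.
    assert [e[2]["loss"] for e in sorted(edges, key=lambda x: x[2]["loss"], reverse=True)] == [0.9, 0.5, 0.1]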
graphgen/operators/traverse_graph.py ADDED
@@ -0,0 +1,485 @@
+ import asyncio
+
+ import gradio as gr
+ from tqdm.asyncio import tqdm as tqdm_async
+
+ from graphgen.models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer, JsonKVStorage
+ from graphgen.operators.split_graph import get_batches_with_strategy
+ from graphgen.templates import ANSWER_REPHRASING_PROMPT, QUESTION_GENERATION_PROMPT, MULTI_HOP_GENERATION_PROMPT
+ from graphgen.utils import detect_main_language, compute_content_hash, logger
+
+
+ async def _pre_tokenize(graph_storage: NetworkXStorage,
+                         tokenizer: Tokenizer,
+                         edges: list,
+                         nodes: list) -> tuple:
+     """Cache the token length of every node and edge description on the graph."""
+     sem = asyncio.Semaphore(1000)
+
+     async def handle_edge(edge: tuple) -> tuple:
+         async with sem:
+             if 'length' not in edge[2]:
+                 edge[2]['length'] = len(
+                     await asyncio.get_event_loop().run_in_executor(
+                         None, tokenizer.encode_string, edge[2]['description']))
+             return edge
+
+     async def handle_node(node: tuple) -> tuple:
+         async with sem:
+             if 'length' not in node[1]:
+                 node[1]['length'] = len(
+                     await asyncio.get_event_loop().run_in_executor(
+                         None, tokenizer.encode_string, node[1]['description']))
+             return node
+
+     new_edges = []
+     new_nodes = []
+
+     for result in tqdm_async(asyncio.as_completed([handle_edge(edge) for edge in edges]),
+                              total=len(edges), desc="Pre-tokenizing edges"):
+         new_edge = await result
+         await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2])
+         new_edges.append(new_edge)
+
+     for result in tqdm_async(asyncio.as_completed([handle_node(node) for node in nodes]),
+                              total=len(nodes), desc="Pre-tokenizing nodes"):
+         new_node = await result
+         await graph_storage.update_node(new_node[0], new_node[1])
+         new_nodes.append(new_node)
+
+     await graph_storage.index_done_callback()
+     return new_edges, new_nodes
+
+
+ async def _construct_rephrasing_prompt(_process_nodes: list,
+                                        _process_edges: list,
+                                        text_chunks_storage: JsonKVStorage,
+                                        add_context: bool = False
+                                        ) -> str:
+     entities = [
+         f"{_process_node['node_id']}: {_process_node['description']}"
+         for _process_node in _process_nodes
+     ]
+     relations = [
+         f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
+         for _process_edge in _process_edges
+     ]
+
+     entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
+     relations_str = "\n".join([f"{index + 1}. {relation}" for index, relation in enumerate(relations)])
+     language = "Chinese" if detect_main_language(entities_str + relations_str) == "zh" else "English"
+
+     if add_context:
+         original_ids = ([node['source_id'].split('<SEP>')[0] for node in _process_nodes] +
+                         [edge[2]['source_id'].split('<SEP>')[0] for edge in _process_edges])
+
+         original_ids = list(set(original_ids))
+         original_text = await text_chunks_storage.get_by_ids(original_ids)
+         original_text = "\n".join([f"{index + 1}. {text['content']}" for index, text in enumerate(original_text)])
+
+         return ANSWER_REPHRASING_PROMPT[language]['CONTEXT_TEMPLATE'].format(
+             language=language,
+             original_text=original_text,
+             entities=entities_str,
+             relationships=relations_str
+         )
+
+     return ANSWER_REPHRASING_PROMPT[language]['TEMPLATE'].format(
+         language=language,
+         entities=entities_str,
+         relationships=relations_str
+     )
+
+
+ def get_loss_tercile(losses: list) -> tuple:
+     """Return the two cut points that split the sorted losses into terciles."""
+     losses = sorted(losses)
+     q1_index = int(len(losses) * (1 / 3))
+     q2_index = int(len(losses) * (2 / 3))
+
+     return losses[q1_index], losses[q2_index]
+
+
+ def get_average_loss(batch: tuple, loss_strategy: str) -> float:
+     if loss_strategy == "only_edge":
+         return sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])
+     if loss_strategy == "both":
+         # Parenthesize the numerator so the mean is taken over both nodes and edges.
+         return (sum(edge[2]['loss'] for edge in batch[1]) + sum(node['loss'] for node in batch[0])) \
+             / (len(batch[0]) + len(batch[1]))
+     raise ValueError("Invalid loss strategy")
+
+
+ def _post_process_synthetic_data(data):
+     blocks = data.split("\n\n")
+     qas = []
+     for line in blocks:
+         if "Question:" in line and "Answer:" in line:
+             question = line.split("Question:")[1].split("Answer:")[0].strip()
+             answer = line.split("Answer:")[1].strip()
+             qas.append({
+                 "question": question,
+                 "answer": answer
+             })
+         elif "问题:" in line and "答案:" in line:
+             question = line.split("问题:")[1].split("答案:")[0].strip()
+             answer = line.split("答案:")[1].strip()
+             qas.append({
+                 "question": question,
+                 "answer": answer
+             })
+         elif "问题:" in line and "回答:" in line:
+             question = line.split("问题:")[1].split("回答:")[0].strip()
+             answer = line.split("回答:")[1].strip()
+             qas.append({
+                 "question": question,
+                 "answer": answer
+             })
+     return qas
+
+
+ async def traverse_graph_by_edge(
+     llm_client: OpenAIModel,
+     tokenizer: Tokenizer,
+     graph_storage: NetworkXStorage,
+     traverse_strategy: TraverseStrategy,
+     text_chunks_storage: JsonKVStorage,
+     progress_bar: gr.Progress = None,
+     max_concurrent: int = 1000
+ ) -> dict:
+     """
+     Traverse the graph edge by edge and generate QA pairs.
+
+     :param llm_client: synthesizer LLM client
+     :param tokenizer: tokenizer used for length accounting
+     :param graph_storage: graph storage instance
+     :param traverse_strategy: traversal strategy
+     :param text_chunks_storage: storage holding the original text chunks
+     :param progress_bar: optional Gradio progress bar
+     :param max_concurrent: maximum number of concurrent LLM calls
+     :return: question and answer pairs
+     """
+
+     semaphore = asyncio.Semaphore(max_concurrent)
+
+     async def _process_nodes_and_edges(
+         _process_nodes: list,
+         _process_edges: list,
+     ) -> str:
+         prompt = await _construct_rephrasing_prompt(
+             _process_nodes,
+             _process_edges,
+             text_chunks_storage,
+             add_context=False
+         )
+         context = await llm_client.generate_answer(prompt)
+
+         # Post-process the context
+         if context.startswith("Rephrased Text:"):
+             context = context[len("Rephrased Text:"):].strip()
+         elif context.startswith("重述文本:"):
+             context = context[len("重述文本:"):].strip()
+
+         return context
+
+     async def _process_single_batch(
+         _process_batch: tuple,
+         question_type: str = "single"
+     ) -> dict:
+         async with semaphore:
+             context = await _process_nodes_and_edges(
+                 _process_batch[0],
+                 _process_batch[1],
+             )
+
+             language = "Chinese" if detect_main_language(context) == "zh" else "English"
+             pre_length = sum(node['length'] for node in _process_batch[0]) \
+                 + sum(edge[2]['length'] for edge in _process_batch[1])
+
+             if question_type == "single":
+                 question = await llm_client.generate_answer(
+                     QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format(
+                         answer=context
+                     )
+                 )
+                 if question.startswith("Question:"):
+                     question = question[len("Question:"):].strip()
+                 elif question.startswith("问题:"):
+                     question = question[len("问题:"):].strip()
+
+                 logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
+                 logger.info("Pre-length: %s", pre_length)
+                 logger.info("Question: %s", question)
+                 logger.info("Answer: %s", context)
+
+                 return {
+                     compute_content_hash(context): {
+                         "question": question,
+                         "answer": context,
+                         "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy)
+                     }
+                 }
+
+             content = await llm_client.generate_answer(
+                 QUESTION_GENERATION_PROMPT[language]['MULTI_TEMPLATE'].format(
+                     doc=context
+                 )
+             )
+             qas = _post_process_synthetic_data(content)
+
+             if len(qas) == 0:
+                 logger.error("Error occurred while processing batch, question or answer is None: %s", content)
+                 return {}
+
+             final_results = {}
+             logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
+             logger.info("Pre-length: %s", pre_length)
+             for qa in qas:
+                 logger.info("Question: %s", qa['question'])
+                 logger.info("Answer: %s", qa['answer'])
+                 final_results[compute_content_hash(qa['question'])] = {
+                     "question": qa['question'],
+                     "answer": qa['answer'],
+                     "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy)
+                 }
+             return final_results
+
+     results = {}
+     edges = list(await graph_storage.get_all_edges())
+     nodes = list(await graph_storage.get_all_nodes())
+
+     edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
+
+     processing_batches = await get_batches_with_strategy(
+         nodes,
+         edges,
+         graph_storage,
+         traverse_strategy
+     )
+
+     for result in tqdm_async(asyncio.as_completed(
+             [_process_single_batch(batch) for batch in processing_batches]
+     ), total=len(processing_batches), desc="[4/4]Generating QAs"):
+         try:
+             if progress_bar is not None:
+                 progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
+             results.update(await result)
+             if progress_bar is not None and len(results) == len(processing_batches):
+                 progress_bar(1, desc="[4/4]Generating QAs")
+         except Exception as e:  # pylint: disable=broad-except
+             logger.error("Error occurred while generating QA: %s", e)
+
+     return results
+
+
+ async def traverse_graph_atomically(
+     llm_client: OpenAIModel,
+     tokenizer: Tokenizer,
+     graph_storage: NetworkXStorage,
+     traverse_strategy: TraverseStrategy,
+     text_chunks_storage: JsonKVStorage,
+     progress_bar: gr.Progress = None,
+     max_concurrent: int = 1000
+ ) -> dict:
+     """
+     Traverse the graph atomically: generate one QA pair per node or edge description.
+
+     :param llm_client: synthesizer LLM client
+     :param tokenizer: tokenizer used for length accounting
+     :param graph_storage: graph storage instance
+     :param traverse_strategy: traversal strategy
+     :param text_chunks_storage: storage holding the original text chunks
+     :param progress_bar: optional Gradio progress bar
+     :param max_concurrent: maximum number of concurrent LLM calls
+     :return: question and answer pairs
+     """
+     assert traverse_strategy.qa_form == "atomic"
+
+     semaphore = asyncio.Semaphore(max_concurrent)
+
+     async def _generate_question(
+         node_or_edge: tuple
+     ):
+         if len(node_or_edge) == 2:
+             des = node_or_edge[0] + ": " + node_or_edge[1]['description']
+             loss = node_or_edge[1]['loss']
+         else:
+             des = node_or_edge[2]['description']
+             loss = node_or_edge[2]['loss']
+
+         async with semaphore:
+             try:
+                 language = "Chinese" if detect_main_language(des) == "zh" else "English"
+
+                 qa = await llm_client.generate_answer(
+                     QUESTION_GENERATION_PROMPT[language]['SINGLE_QA_TEMPLATE'].format(
+                         doc=des
+                     )
+                 )
+
+                 if "Question:" in qa and "Answer:" in qa:
+                     question = qa.split("Question:")[1].split("Answer:")[0].strip()
+                     answer = qa.split("Answer:")[1].strip()
+                 elif "问题:" in qa and "答案:" in qa:
+                     question = qa.split("问题:")[1].split("答案:")[0].strip()
+                     answer = qa.split("答案:")[1].strip()
+                 else:
+                     return {}
+
+                 question = question.strip("\"")
+                 answer = answer.strip("\"")
+
+                 logger.info("Question: %s", question)
+                 logger.info("Answer: %s", answer)
+                 return {
+                     compute_content_hash(question): {
+                         "question": question,
+                         "answer": answer,
+                         "loss": loss
+                     }
+                 }
+             except Exception as e:  # pylint: disable=broad-except
+                 logger.error("Error occurred while generating question: %s", e)
+                 return {}
+
+     results = {}
+     edges = list(await graph_storage.get_all_edges())
+     nodes = list(await graph_storage.get_all_nodes())
+
+     edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
+
+     tasks = []
+     for node in nodes:
+         if "<SEP>" in node[1]['description']:
+             for item in node[1]['description'].split("<SEP>"):
+                 tasks.append((node[0], {"description": item, "loss": node[1]['loss']}))
+         else:
+             tasks.append((node[0], node[1]))
+     for edge in edges:
+         if "<SEP>" in edge[2]['description']:
+             for item in edge[2]['description'].split("<SEP>"):
+                 tasks.append((edge[0], edge[1], {"description": item, "loss": edge[2]['loss']}))
+         else:
+             tasks.append((edge[0], edge[1], edge[2]))
+
+     for result in tqdm_async(
+             asyncio.as_completed([_generate_question(task) for task in tasks]),
+             total=len(tasks),
+             desc="[4/4]Generating QAs"
+     ):
+         try:
+             if progress_bar is not None:
+                 progress_bar(len(results) / len(tasks), desc="[4/4]Generating QAs")
+             results.update(await result)
+             if progress_bar is not None and len(results) == len(tasks):
+                 progress_bar(1, desc="[4/4]Generating QAs")
+         except Exception as e:  # pylint: disable=broad-except
+             logger.error("Error occurred while generating QA: %s", e)
+     return results
+
+
+ async def traverse_graph_for_multi_hop(
+     llm_client: OpenAIModel,
+     tokenizer: Tokenizer,
+     graph_storage: NetworkXStorage,
+     traverse_strategy: TraverseStrategy,
+     text_chunks_storage: JsonKVStorage,
+     progress_bar: gr.Progress = None,
+     max_concurrent: int = 1000
+ ) -> dict:
+     """
+     Traverse the graph to generate multi-hop QA pairs.
+
+     :param llm_client: synthesizer LLM client
+     :param tokenizer: tokenizer used for length accounting
+     :param graph_storage: graph storage instance
+     :param traverse_strategy: traversal strategy
+     :param text_chunks_storage: storage holding the original text chunks
+     :param progress_bar: optional Gradio progress bar
+     :param max_concurrent: maximum number of concurrent LLM calls
+     :return: question and answer pairs
+     """
+     assert traverse_strategy.qa_form == "multi_hop"
+
+     semaphore = asyncio.Semaphore(max_concurrent)
+
+     results = {}
+     edges = list(await graph_storage.get_all_edges())
+     nodes = list(await graph_storage.get_all_nodes())
+
+     edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
+
+     processing_batches = await get_batches_with_strategy(
+         nodes,
+         edges,
+         graph_storage,
+         traverse_strategy
+     )
+
+     async def _process_single_batch(
+         _process_batch: tuple
+     ) -> dict:
+         async with semaphore:
+             try:
+                 language = "Chinese" if detect_main_language(_process_batch[0][0]['description']) == "zh" \
+                     else "English"
+
+                 _process_nodes = _process_batch[0]
+                 _process_edges = _process_batch[1]
+
+                 entities = [
+                     f"{_process_node['node_id']}: {_process_node['description']}"
+                     for _process_node in _process_nodes
+                 ]
+
+                 relations = [
+                     f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
+                     for _process_edge in _process_edges
+                 ]
+
+                 entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
+                 relations_str = "\n".join([f"{index + 1}. {relation}" for index, relation in enumerate(relations)])
+
+                 prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
+                     entities=entities_str,
+                     relationships=relations_str
+                 )
+
+                 context = await llm_client.generate_answer(prompt)
+
+                 # Post-process the context
+                 if "Question:" in context and "Answer:" in context:
+                     question = context.split("Question:")[1].split("Answer:")[0].strip()
+                     answer = context.split("Answer:")[1].strip()
+                 elif "问题:" in context and "答案:" in context:
+                     question = context.split("问题:")[1].split("答案:")[0].strip()
+                     answer = context.split("答案:")[1].strip()
+                 else:
+                     return {}
+
+                 question = question.strip("\"")
+                 answer = answer.strip("\"")
+
+                 logger.info("Question: %s", question)
+                 logger.info("Answer: %s", answer)
+
+                 return {
+                     compute_content_hash(question): {
+                         "question": question,
+                         "answer": answer,
+                         "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
+                     }
+                 }
+
+             except Exception as e:  # pylint: disable=broad-except
+                 logger.error("Error occurred while processing batch: %s", e)
+                 return {}
+
+     async for result in tqdm_async(
+             asyncio.as_completed([_process_single_batch(batch) for batch in processing_batches]),
+             total=len(processing_batches),
+             desc="[4/4]Generating QAs"
+     ):
+         try:
+             if progress_bar is not None:
+                 progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
+             results.update(await result)
+             if progress_bar is not None and len(results) == len(processing_batches):
+                 progress_bar(1, desc="[4/4]Generating QAs")
+         except Exception as e:  # pylint: disable=broad-except
+             logger.error("Error occurred while generating QA: %s", e)
+     return results
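A standalone illustration of the "Question:"/"Answer:" parsing performed by _post_process_synthetic_data above, on a made-up model response:

    raw = ("Question: Which institute bred Huanghuazhan?\n"
           "Answer: The Rice Research Institute of the Guangdong Academy of Agricultural Sciences.\n"
           "\n"
           "Question: Where was it demonstrated?\n"
           "Answer: Wang'er Town, Yanshan County.")

    # Each blank-line-separated block yields one {"question": ..., "answer": ...} dict.
    for block in raw.split("\n\n"):
        question = block.split("Question:")[1].split("Answer:")[0].strip()
        answer = block.split("Answer:")[1].strip()
        print({"question": question, "answer": answer})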
graphgen/templates/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .kg_extraction import KG_EXTRACTION_PROMPT
+ from .kg_summarization import KG_SUMMARIZATION_PROMPT
+ from .search_judgement import SEARCH_JUDGEMENT_PROMPT
+ from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
+ from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
+ from .answer_rephrasing import ANSWER_REPHRASING_PROMPT
+ from .question_generation import QUESTION_GENERATION_PROMPT
+ from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
+ from .coreference_resolution import COREFERENCE_RESOLUTION_TEMPLATE
graphgen/templates/answer_rephrasing.py ADDED
@@ -0,0 +1,219 @@
+ TEMPLATE_CONTEXT_EN: str = """---Role---
+
+ You are an NLP expert responsible for generating a logically structured and coherent rephrased version of the TEXT based on the ENTITIES and RELATIONSHIPS provided below. You may refer to the original text to assist in generating the rephrased version, but ensure that the final output text meets the requirements.
+ Use {language} as output language.
+
+ ---Goal---
+ To generate a rephrased version of the text that conveys the same meaning as the original entity and relationship descriptions, while:
+ 1. Following a clear logical flow and structure
+ 2. Establishing proper cause-and-effect relationships
+ 3. Ensuring temporal and sequential consistency
+ 4. Creating smooth transitions between ideas using conjunctions and appropriate linking words like "firstly," "however," "therefore," etc.
+
+ ---Instructions---
+ 1. Analyze the provided ENTITIES and RELATIONSHIPS carefully to identify:
+    - Key concepts and their hierarchies
+    - Temporal sequences and chronological order
+    - Cause-and-effect relationships
+    - Dependencies between different elements
+
+ 2. Organize the information in a logical sequence by:
+    - Starting with foundational concepts
+    - Building up to more complex relationships
+    - Grouping related ideas together
+    - Creating clear transitions between sections
+
+ 3. Rephrase the text while maintaining:
+    - Logical flow and progression
+    - Clear connections between ideas
+    - Proper context and background
+    - Coherent narrative structure
+
+ 4. Review and refine the text to ensure:
+    - Logical consistency throughout
+    - Clear cause-and-effect relationships
+
+ ################
+ -ORIGINAL TEXT-
+ ################
+ {original_text}
+
+ ################
+ -ENTITIES-
+ ################
+ {entities}
+
+ ################
+ -RELATIONSHIPS-
+ ################
+ {relationships}
+
+ """
+
+ TEMPLATE_CONTEXT_ZH: str = """---角色---
+
+ 你是一位NLP专家,负责根据下面提供的实体和关系生成逻辑结构清晰且连贯的文本重述版本。你可以参考原始文本辅助生成,但需要确保最终输出的文本符合要求。
+ 使用{language}作为输出语言。
+
+ ---目标---
+
+ 生成文本的重述版本,使其传达与原始实体和关系描述相同的含义,同时:
+ 1. 遵循清晰的逻辑流和结构
+ 2. 建立适当的因果关系
+ 3. 确保时间和顺序的一致性
+ 4. 使用连词和适当的连接词(如"首先"、"然而"、"因此"等)创造流畅的过渡
+
+ ---说明---
+ 1. 仔细分析提供的实体和关系,以识别:
+    - 关键概念及其层级关系
+    - 时间序列和时间顺序
+    - 因果关系
+    - 不同元素之间的依赖关系
+ 2. 通过以下方式将信息组织成逻辑顺序:
+    - 从基础概念开始
+    - 逐步建立更复杂的关系
+    - 将相关的想法分组在一起
+    - 在各部分之间创建清晰的过渡
+ 3. 重述文本时保持:
+    - 逻辑流畅
+    - 概念之间的清晰联系
+    - 适当的上下文和背景
+    - 连贯的叙述结构
+ 4. 检查和完善文本以确保:
+    - 整体逻辑一致性
+    - 清晰的因果关系
+
+ ################
+ -原始文本-
+ ################
+ {original_text}
+
+ ################
+ -实体-
+ ################
+ {entities}
+
+ ################
+ -关系-
+ ################
+ {relationships}
+
+ """
+
+ TEMPLATE_EN: str = """---Role---
+
+ You are an NLP expert responsible for generating a logically structured and coherent rephrased version of the TEXT based on the ENTITIES and RELATIONSHIPS provided below.
+ Use {language} as output language.
+
+ ---Goal---
+ To generate a rephrased version of the text that conveys the same meaning as the original entity and relationship descriptions, while:
+ 1. Following a clear logical flow and structure
+ 2. Establishing proper cause-and-effect relationships
+ 3. Ensuring temporal and sequential consistency
+ 4. Creating smooth transitions between ideas using conjunctions and appropriate linking words like "firstly," "however," "therefore," etc.
+
+ ---Instructions---
+ 1. Analyze the provided ENTITIES and RELATIONSHIPS carefully to identify:
+    - Key concepts and their hierarchies
+    - Temporal sequences and chronological order
+    - Cause-and-effect relationships
+    - Dependencies between different elements
+
+ 2. Organize the information in a logical sequence by:
+    - Starting with foundational concepts
+    - Building up to more complex relationships
+    - Grouping related ideas together
+    - Creating clear transitions between sections
+
+ 3. Rephrase the text while maintaining:
+    - Logical flow and progression
+    - Clear connections between ideas
+    - Proper context and background
+    - Coherent narrative structure
+
+ 4. Review and refine the text to ensure:
+    - Logical consistency throughout
+    - Clear cause-and-effect relationships
+
+ ################
+ -ENTITIES-
+ ################
+ {entities}
+
+ ################
+ -RELATIONSHIPS-
+ ################
+ {relationships}
+
+ """
+
+ TEMPLATE_ZH: str = """---角色---
+
+ 你是一位NLP专家,负责根据下面提供的实体和关系生成逻辑结构清晰且连贯的文本重述版本。
+ 使用{language}作为输出语言。
+
+ ---目标---
+
+ 生成文本的重述版本,使其传达与原始实体和关系描述相同的含义,同时:
+ 1. 遵循清晰的逻辑流和结构
+ 2. 建立适当的因果关系
+ 3. 确保时间和顺序的一致性
+ 4. 使用连词和适当的连接词(如"首先"、"然而"、"因此"等)创造流畅的过渡
+
+ ---说明---
+ 1. 仔细分析提供的实体和关系,以识别:
+    - 关键概念及其层级关系
+    - 时间序列和时间顺序
+    - 因果关系
+    - 不同元素之间的依赖关系
+ 2. 通过以下方式将信息组织成逻辑顺序:
+    - 从基础概念开始
+    - 逐步建立更复杂的关系
+    - 将相关的想法分组在一起
+    - 在各部分之间创建清晰的过渡
+ 3. 重述文本时保持:
+    - 逻辑流畅
+    - 概念之间的清晰联系
+    - 适当的上下文和背景
+    - 连贯的叙述结构
+ 4. 检查和完善文本以确保:
+    - 整体逻辑一致性
+    - 清晰的因果关系
+
+ ################
+ -实体-
+ ################
+ {entities}
+
+ ################
+ -关系-
+ ################
+ {relationships}
+
+ """
+
+ REQUIREMENT_ZH = """
+ ################
+ 请在下方直接输出连贯的重述文本,不要输出任何额外的内容。
+
+ 重述文本:
+ """
+
+ REQUIREMENT_EN = """
+ ################
+ Please directly output the coherent rephrased text below, without any additional content.
+
+ Rephrased Text:
+ """
+
+
+ ANSWER_REPHRASING_PROMPT = {
+     "English": {
+         "TEMPLATE": TEMPLATE_EN + REQUIREMENT_EN,
+         "CONTEXT_TEMPLATE": TEMPLATE_CONTEXT_EN + REQUIREMENT_EN
+     },
+     "Chinese": {
+         "TEMPLATE": TEMPLATE_ZH + REQUIREMENT_ZH,
+         "CONTEXT_TEMPLATE": TEMPLATE_CONTEXT_ZH + REQUIREMENT_ZH
+     }
+ }
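A minimal sketch of how these templates are consumed (mirroring _construct_rephrasing_prompt in graphgen/operators/traverse_graph.py); the entity and relationship strings are invented examples:

    from graphgen.templates import ANSWER_REPHRASING_PROMPT

    entities_str = "1. Trajan: Roman emperor known for his administrative abilities."
    relations_str = "1. Trajan -- Roman Empire: Trajan governed during a prosperous period."

    prompt = ANSWER_REPHRASING_PROMPT["English"]["TEMPLATE"].format(
        language="English",
        entities=entities_str,
        relationships=relations_str,
    )
    # The reply is expected to begin with "Rephrased Text:", which the caller strips.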
graphgen/templates/coreference_resolution.py ADDED
@@ -0,0 +1,39 @@
+ # pylint: disable=C0301
+ TEMPLATE_ZH: str = """请根据参考文本识别并消解文本中的指代词,明确每个代词所指代的具体实体,并直接输出消解后的文本。
+
+ -示例-
+ 输入:
+ 小明和小红一起去公园。她们玩得很开心。之后,他们去吃冰淇淋。
+ 输出:
+ 小明和小红一起去公园。小明和小红玩得很开心。之后,小明和小红去吃冰淇淋。
+
+ -真实数据-
+ 参考文本:
+ {reference}
+ 输入:
+ {input_sentence}
+ 请直接输出改写后的句子,不要输出任何额外信息。
+ 输出:
+ """
+
+ TEMPLATE_EN: str = """Based on the reference text, please identify and resolve the pronouns in the input text, make explicit the specific entity each pronoun refers to, and directly output the resolved text.
+
+ -Example-
+ Input:
+ John and Mary went to the park. They had a great time. Later, they went to eat ice cream.
+ Output:
+ John and Mary went to the park. John and Mary had a great time. Later, John and Mary went to eat ice cream.
+
+ -Real Data-
+ Reference text:
+ {reference}
+ Input:
+ {input_sentence}
+ Please directly output the rewritten sentence without any additional information.
+ Output:
+ """
+
+ COREFERENCE_RESOLUTION_TEMPLATE = {
+     "en": TEMPLATE_EN,
+     "zh": TEMPLATE_ZH
+ }
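A minimal usage sketch for this template; the reference and input sentences are invented:

    from graphgen.templates import COREFERENCE_RESOLUTION_TEMPLATE

    prompt = COREFERENCE_RESOLUTION_TEMPLATE["en"].format(
        reference="John met Mary at the park.",
        input_sentence="He said she looked happy.",
    )
    # Expected model output: "John said Mary looked happy."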
graphgen/templates/description_rephrasing.py ADDED
@@ -0,0 +1,121 @@
+ ANTI_TEMPLATE_EN: str = """-Goal-
+ Transform the input sentence into its opposite meaning while:
+
+ 1. Preserving most of the original sentence structure
+ 2. Changing only key words that affect the core meaning
+ 3. Maintaining the same tone and style
+ 4. The input sentence provided is a correct description, and the output sentence should be an incorrect description
+ 5. The output sentence should be fluent and grammatically correct
+
+ ################
+ -Examples-
+ ################
+ Input:
+ The bright sunshine made everyone feel energetic and happy.
+
+ Output:
+ The bright sunshine made everyone feel tired and sad.
+
+ ################
+ -Real Data-
+ ################
+ Input:
+ {input_sentence}
+ ################
+ Please directly output the rewritten sentence without any additional information.
+ Output:
+ """
+
+ ANTI_TEMPLATE_ZH: str = """-目标-
+ 将输入句子转换为相反含义的句子,同时:
+
+ 1. 保留大部分原始句子结构
+ 2. 仅更改影响核心含义的关键词
+ 3. 保持相同的语气和风格
+ 4. 提供的输入句子是一个正确的描述,输出句子应该是一个错误的描述
+ 5. 输出句子应该流畅且语法正确
+
+ ################
+ -示例-
+ ################
+ 输入:
+ 明亮的阳光让每个人都感到充满活力和快乐。
+
+ 输出:
+ 明亮的阳光让每个人都感到疲惫和悲伤。
+
+ ################
+ -真实数据-
+ ################
+ 输入:
+ {input_sentence}
+ ################
+ 请直接输出改写后的句子,不要输出任何额外信息。
+ 输出:
+ """
+
+ TEMPLATE_ZH: str = """-目标-
+ 将输入句子转换为相同含义的句子,同时:
+
+ 1. 保留大部分原始句子结构
+ 2. 仅更改影响核心含义的关键词
+ 3. 保持相同的语气和风格
+ 4. 输出句子应该流畅且语法正确
+
+ ################
+ -示例-
+ ################
+ 输入:
+ 明亮的阳光让每个人都感到充满活力和快乐。
+
+ 输出:
+ 明媚的阳光让每个人都感受到活力与快乐。
+
+ ################
+ -真实数据-
+ ################
+ 输入:
+ {input_sentence}
+ ################
+ 请直接输出改写后的句子,不要输出任何额外信息。
+ 输出:
+ """
+
+ TEMPLATE_EN: str = """-Goal-
+ Transform the input sentence into a sentence with the same meaning while:
+
+ 1. Preserving most of the original sentence structure
+ 2. Changing only key words that affect the core meaning
+ 3. Maintaining the same tone and style
+ 4. The output sentence should be fluent and grammatically correct
+
+ ################
+ -Examples-
+ ################
+ Input:
+ The bright sunshine made everyone feel energetic and happy.
+
+ Output:
+ The bright sunshine made everyone feel energetic and joyful.
+
+ ################
+ -Real Data-
+ ################
+ Input:
+ {input_sentence}
+ ################
+ Please directly output the rewritten sentence without any additional information.
+ Output:
+ """
+
+
+ DESCRIPTION_REPHRASING_PROMPT = {
+     "English": {
+         "ANTI_TEMPLATE": ANTI_TEMPLATE_EN,
+         "TEMPLATE": TEMPLATE_EN
+     },
+     "Chinese": {
+         "ANTI_TEMPLATE": ANTI_TEMPLATE_ZH,
+         "TEMPLATE": TEMPLATE_ZH
+     }
+ }
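A minimal usage sketch; the input sentence is taken from the examples above. TEMPLATE asks for a faithful paraphrase, while ANTI_TEMPLATE asks for a deliberately wrong one:

    from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT

    sentence = "The bright sunshine made everyone feel energetic and happy."
    faithful_prompt = DESCRIPTION_REPHRASING_PROMPT["English"]["TEMPLATE"].format(input_sentence=sentence)
    negated_prompt = DESCRIPTION_REPHRASING_PROMPT["English"]["ANTI_TEMPLATE"].format(input_sentence=sentence)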
graphgen/templates/kg_extraction.py ADDED
@@ -0,0 +1,210 @@
+ # pylint: disable=C0301
+
+ TEMPLATE_EN: str = """You are an NLP expert, skilled at analyzing text to extract named entities and their relationships.
+
+ -Goal-
+ Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
+ Use {language} as output language.
+
+ -Steps-
+ 1. Identify all entities. For each identified entity, extract the following information:
+ - entity_name: Name of the entity, use same language as input text. If English, capitalize the name.
+ - entity_type: One of the following types: [{entity_types}]
+ - entity_summary: Comprehensive summary of the entity's attributes and activities
+ Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_summary>)
+
+ 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+ For each pair of related entities, extract the following information:
+ - source_entity: name of the source entity, as identified in step 1
+ - target_entity: name of the target entity, as identified in step 1
+ - relationship_summary: explanation as to why you think the source entity and the target entity are related to each other
+ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_summary>)
+
+ 3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
+ Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
+
+ 4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
+
+ 5. When finished, output {completion_delimiter}
+
+ ################
+ -Examples-
+ ################
+ -Example 1-
+ Text:
+ ################
+ In the second century of the Christian Era, the empire of Rome comprehended the fairest part of the earth, and the most civilized portion of mankind. The frontiers of that extensive monarchy were guarded by ancient renown and disciplined valor. The gentle but powerful influence of laws and manners had gradually cemented the union of the provinces. Their peaceful inhabitants enjoyed and abused the advantages of wealth and luxury. The image of a free constitution was preserved with decent reverence: the Roman senate appeared to possess the sovereign authority, and devolved on the emperors all the executive powers of government. During a happy period of more than fourscore years, the public administration was conducted by the virtue and abilities of Nerva, Trajan, Hadrian, and the two Antonines.
+ ################
+ Output:
+ ("entity"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"organization"{tuple_delimiter}"The dominant empire of the second century CE, encompassing the most developed regions of the known world."){record_delimiter}
+ ("entity"{tuple_delimiter}"Second Century CE"{tuple_delimiter}"date"{tuple_delimiter}"Time period of the Christian Era when the Roman Empire was at its height."){record_delimiter}
+ ("entity"{tuple_delimiter}"Rome"{tuple_delimiter}"location"{tuple_delimiter}"The capital and heart of the Roman Empire."){record_delimiter}
+ ("entity"{tuple_delimiter}"Roman Senate"{tuple_delimiter}"organization"{tuple_delimiter}"Legislative body that appeared to hold sovereign authority in Rome."){record_delimiter}
+ ("entity"{tuple_delimiter}"Nerva"{tuple_delimiter}"person"{tuple_delimiter}"Roman emperor who contributed to the public administration during a prosperous period."){record_delimiter}
+ ("entity"{tuple_delimiter}"Trajan"{tuple_delimiter}"person"{tuple_delimiter}"Roman emperor known for his virtue and administrative abilities."){record_delimiter}
+ ("entity"{tuple_delimiter}"Hadrian"{tuple_delimiter}"person"{tuple_delimiter}"Roman emperor who governed during the empire's peaceful period."){record_delimiter}
+ ("entity"{tuple_delimiter}"Antonines"{tuple_delimiter}"person"{tuple_delimiter}"Two Roman emperors who ruled during a period of prosperity and good governance."){record_delimiter}
+ ("entity"{tuple_delimiter}"Roman Law"{tuple_delimiter}"concept"{tuple_delimiter}"System of laws and manners that unified the provinces of the Roman Empire."){record_delimiter}
+ ("relationship"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"Roman Law"{tuple_delimiter}"The empire was unified and maintained through the influence of its laws and customs."){record_delimiter}
+ ("relationship"{tuple_delimiter}"Roman Senate"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"The Senate appeared to possess sovereign authority while delegating executive powers to emperors."){record_delimiter}
+ ("relationship"{tuple_delimiter}"Nerva"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"Nerva was one of the emperors who contributed to the empire's successful administration."){record_delimiter}
+ ("relationship"{tuple_delimiter}"Trajan"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"Trajan was one of the emperors who governed during the empire's prosperous period."){record_delimiter}
+ ("relationship"{tuple_delimiter}"Hadrian"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"Hadrian was one of the emperors who managed the empire's administration effectively."){record_delimiter}
+ ("relationship"{tuple_delimiter}"Antonines"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"The Antonines were emperors who helped maintain the empire's prosperity through their governance."){record_delimiter}
+ ("content_keywords"{tuple_delimiter}"Roman governance, imperial prosperity, law and order, civilized society"){completion_delimiter}
+
+ -Example 2-
+ Text:
+ #############
+ Overall, the analysis of the OsDT11 sequence demonstrated that this protein belongs to the CRP family. Since OsDT11 is predicted to be a secreted protein, the subcellular localization of OsDT11 was determined by fusing the OsDT11 ORF to RFP in a p35S::RFP vector by in vivo protein targeting in NB epidermal cells by performing an Agrobacterium tumefaciens-mediated transient assay. After incubation for 48 h, the RFP signals were mainly detected in the cell-wall of OsDT11-RFP transformed cells, while the control cells (transformed with the RFP construct) displayed ubiquitous RFP signals, demonstrating that OsDT11 is a secreted signal peptide. Moreover, when the infiltrated leaf sections were plasmolyzed, the OsDT11-RFP fusion proteins were located on the cell wall.
+ #############
+ Output:
+ ("entity"{tuple_delimiter}"OsDT11"{tuple_delimiter}"gene"{tuple_delimiter}"A protein sequence belonging to the CRP family, demonstrated to be a secreted signal peptide that localizes to cell walls."){record_delimiter}
+ ("entity"{tuple_delimiter}"CRP family"{tuple_delimiter}"science"{tuple_delimiter}"A protein family to which OsDT11 belongs, characterized by specific structural and functional properties."){record_delimiter}
+ ("entity"{tuple_delimiter}"RFP"{tuple_delimiter}"technology"{tuple_delimiter}"Red Fluorescent Protein, used as a fusion marker to track protein localization in cells."){record_delimiter}
+ ("entity"{tuple_delimiter}"p35S::RFP vector"{tuple_delimiter}"technology"{tuple_delimiter}"A genetic construct used for protein expression and visualization studies, containing the 35S promoter and RFP marker."){record_delimiter}
+ ("entity"{tuple_delimiter}"NB epidermal cells"{tuple_delimiter}"nature"{tuple_delimiter}"Plant epidermal cells used as the experimental system for protein localization studies."){record_delimiter}
+ ("entity"{tuple_delimiter}"Agrobacterium tumefaciens"{tuple_delimiter}"nature"{tuple_delimiter}"A bacterial species used for transferring genetic material into plant cells in laboratory experiments."){record_delimiter}
+ ("relationship"{tuple_delimiter}"OsDT11"{tuple_delimiter}"CRP family"{tuple_delimiter}"OsDT11 is identified as a member of the CRP family through sequence analysis."){record_delimiter}
+ ("relationship"{tuple_delimiter}"OsDT11"{tuple_delimiter}"RFP"{tuple_delimiter}"OsDT11 was fused to RFP to study its cellular localization."){record_delimiter}
+ ("relationship"{tuple_delimiter}"Agrobacterium tumefaciens"{tuple_delimiter}"NB epidermal cells"{tuple_delimiter}"Agrobacterium tumefaciens was used to transfer genetic material into NB epidermal cells through a transient assay."){record_delimiter}
+ ("relationship"{tuple_delimiter}"OsDT11"{tuple_delimiter}"NB epidermal cells"{tuple_delimiter}"OsDT11's subcellular localization was studied in NB epidermal cells, showing cell wall targeting."){record_delimiter}
+ ("content_keywords"{tuple_delimiter}"protein localization, gene expression, cellular biology, molecular techniques"){completion_delimiter}
+
+ ################
+ -Real Data-
+ ################
+ Entity_types: {entity_types}
+ Text: {input_text}
+ ################
+ Output:
+ """
+
+
+ TEMPLATE_ZH: str = """你是一个NLP专家,擅长分析文本提取命名实体和关系。
+
+ -目标-
+ 给定一个实体类型列表和可能与列表相关的文本,从文本中识别所有这些类型的实体,以及这些实体之间所有的关系。
+ 使用{language}作为输出语言。
+
+ -步骤-
+ 1. 识别所有实体。对于每个识别的实体,提取以下信息:
+ - entity_name:实体的名称,首字母大写
+ - entity_type:以下类型之一:[{entity_types}]
+ - entity_summary:实体的属性与活动的全面总结
+ 将每个实体格式化为("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_summary>)
+
+ 2. 从步骤1中识别的实体中,识别所有(源实体,目标实体)对,这些实体彼此之间*明显相关*。
+ 对于每对相关的实体,提取以下信息:
+ - source_entity:步骤1中识别的源实体名称
+ - target_entity:步骤1中识别的目标实体名称
+ - relationship_summary:解释为什么你认为源实体和目标实体彼此相关
+ 将每个关系格式化为("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_summary>)
+
+ 3. 识别总结整个文本的主要概念、主题或话题的高级关键词。这些应该捕捉文档中存在的总体思想。
+ 将内容级关键词格式化为("content_keywords"{tuple_delimiter}<high_level_keywords>)
+
+ 4. 以中文返回步骤1和2中识别出的所有实体和关系的输出列表。使用**{record_delimiter}**作为列表分隔符。
+
+ 5. 完成后,输出{completion_delimiter}
+
+ ################
+ -示例-
+ ################
+ -示例 1-
+ 文本:
+ ################
+ 鲁镇的酒店的格局,是和别处不同的:都是当街一个曲尺形的大柜台,柜里面预备着热水,可以随时温酒。做工的人,傍午傍晚散了工,每每花四文铜钱,买一碗酒,——这是二十多年前的事,现在每碗要涨到十文,——靠柜外站着,热热的喝了休息;倘肯多花一文,便可以买一碟盐煮笋,或者茴香豆,做下酒物了,如果出到十几文,那就能买一样荤菜,但这些顾客,多是短衣帮,大抵没有这样阔绰。只有穿长衫的,才踱进店面隔壁的房子里,要酒要菜,慢慢地坐喝。
+ ################
+ 输出:
+ ("entity"{tuple_delimiter}"鲁镇的酒店"{tuple_delimiter}"location"{tuple_delimiter}"鲁镇的酒店是一个特定地点,其格局独特,柜台形状为曲尺形,提供热水温酒服务。"){record_delimiter}
+ ("entity"{tuple_delimiter}"曲尺形的大柜台"{tuple_delimiter}"keyword"{tuple_delimiter}"曲尺形的大柜台是鲁镇酒店内独特的设施,用于提供服务。"){record_delimiter}
+ ("entity"{tuple_delimiter}"热水温酒"{tuple_delimiter}"keyword"{tuple_delimiter}"热水温酒是鲁镇酒店提供的一项服务,顾客可以随时温酒。"){record_delimiter}
+ ("entity"{tuple_delimiter}"做工的人"{tuple_delimiter}"person"{tuple_delimiter}"做工的人是鲁镇酒店的常客,通常在工作结束后花四文铜钱买一碗酒,有时还会买一些下酒菜。"){record_delimiter}
+ ("entity"{tuple_delimiter}"二十多年前的事"{tuple_delimiter}"date"{tuple_delimiter}"二十多年前的事是指过去的时间点,当时一碗酒的价格为四文铜钱。"){record_delimiter}
+ ("entity"{tuple_delimiter}"现在"{tuple_delimiter}"date"{tuple_delimiter}"现在是指当前的时间点,与过去相比,一碗酒的价格涨到了十文。"){record_delimiter}
+ ("entity"{tuple_delimiter}"短衣帮"{tuple_delimiter}"concept"{tuple_delimiter}"短衣帮是指做工的人,他们通常穿着短衣,经济条件有限。"){record_delimiter}
+ ("entity"{tuple_delimiter}"穿长衫的"{tuple_delimiter}"person"{tuple_delimiter}"穿长衫的是鲁镇酒店的另一类顾客,他们经济条件较好,通常会进入店面隔壁的房间慢慢喝酒吃菜。"){record_delimiter}
+ ("entity"{tuple_delimiter}"盐煮笋"{tuple_delimiter}"food"{tuple_delimiter}"盐煮笋是鲁镇酒店提供的一种下酒菜,顾客可以花一文铜钱购买。"){record_delimiter}
+ ("entity"{tuple_delimiter}"茴香豆"{tuple_delimiter}"food"{tuple_delimiter}"茴香豆是鲁镇酒店提供的另一种下酒菜,顾客可以花一文铜钱购买。"){record_delimiter}
+ ("entity"{tuple_delimiter}"荤菜"{tuple_delimiter}"food"{tuple_delimiter}"荤菜是鲁镇酒店提供的较为昂贵的菜品,顾客需要花十几文铜钱购买。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"鲁镇的酒店"{tuple_delimiter}"曲尺形的大柜台"{tuple_delimiter}"鲁镇的酒店内设有一个曲尺形的大柜台,用于提供服务。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"鲁镇的酒店"{tuple_delimiter}"热水温酒"{tuple_delimiter}"鲁镇的酒店提供热水温酒服务,顾客可以随时温酒。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"做工的人"{tuple_delimiter}"二十多年前的事"{tuple_delimiter}"做工的人在二十多年前花四文铜钱买一碗酒,反映了当时的生活成本。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"做工的人"{tuple_delimiter}"现在"{tuple_delimiter}"现在做工的人需要花十文铜钱买一碗酒,反映了物价的上涨。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"做工的人"{tuple_delimiter}"短衣帮"{tuple_delimiter}"做工的人属于短衣帮,通常经济条件有限。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"做工的人"{tuple_delimiter}"穿长衫的"{tuple_delimiter}"做工的人与穿长衫的形成对比,反映了社会阶层的差异。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"穿长衫的"{tuple_delimiter}"鲁镇的酒店"{tuple_delimiter}"穿长衫的顾客通常会进入鲁镇酒店的房间慢慢喝酒吃菜,享受更高级的服务。"){record_delimiter}
+ ("content_keywords"{tuple_delimiter}"社会分层, 经济差距, 服务, 生活成本, 历史背景"){completion_delimiter}
+
+ -示例 2-
+ 文本:
+ ################
+ 黄华占是感温型常规稻品种,2016—2017 年在铅山县汪二镇作中稻示范种植综合表现优良。结合示范情况,对黄华占的特征特性作简单总结,在此基础上提出高产栽培技术,以期为该品种的推广种植提供参考。近年来,铅山县粮食生产紧紧围绕“稳产、优质、增效”的总体要求、大力实施优质稻推广,积极引导粮食生产由增产转向提质。我国杂交水稻技术世界领先、优质稻品种众多,在市场走势方面(尤其稻米行情清淡期),人们习惯性地北涨看长粒香、南涨看黄华占。黄华占是广东省农业科学院水稻研究所以黄新占/丰华占为亲本选育而成,分别通过粤、湘、鄂、浙、桂、琼等省审定。为了更好、更快地推广黄华占水稻,铅山县分别于2016 年、2017 年在汪二镇火田村试验示范种植黄华占近 5.87 hm^2 ,综合表现优良。现将黄华占水稻的特征特性及高产栽培技术介绍如下。
+ ################
+ 输出:
+ ("entity"{tuple_delimiter}"黄华占"{tuple_delimiter}"work"{tuple_delimiter}"黄华占是一种感温型常规稻品种,由广东省农业科学院水稻研究所选育,通过多个省份审定,2016-2017年在铅山县汪二镇进行示范种植,表现优良。"){record_delimiter}
+ ("entity"{tuple_delimiter}"2016—2017年"{tuple_delimiter}"date"{tuple_delimiter}"2016—2017年是黄华占在铅山县汪二镇进行示范种植的时间段。"){record_delimiter}
+ ("entity"{tuple_delimiter}"铅山县"{tuple_delimiter}"location"{tuple_delimiter}"铅山县位于中国江西省,是黄华占水稻示范种植的地点之一。"){record_delimiter}
+ ("entity"{tuple_delimiter}"汪二镇"{tuple_delimiter}"location"{tuple_delimiter}"汪二镇是铅山县的一个镇,2016-2017年在此进行了黄华占水稻的示范种植。"){record_delimiter}
+ ("entity"{tuple_delimiter}"火田村"{tuple_delimiter}"location"{tuple_delimiter}"火田村是汪二镇的一个村庄,2016-2017年在此进行了黄华占水稻的试验示范种植。"){record_delimiter}
+ ("entity"{tuple_delimiter}"广东省农业科学院水稻研究所"{tuple_delimiter}"organization"{tuple_delimiter}"广东省农业科学院水稻研究所是中国的一个科研机构,负责黄华占水稻的选育工作。"){record_delimiter}
+ ("entity"{tuple_delimiter}"黄新占/丰华占"{tuple_delimiter}"work"{tuple_delimiter}"黄新占和丰华占是黄华占水稻的亲本,用于选育黄华占。"){record_delimiter}
+ ("entity"{tuple_delimiter}"粤、湘、鄂、浙、桂、琼等省"{tuple_delimiter}"location"{tuple_delimiter}"这些省份通过了黄华占水稻的审定,表明该品种在这些地区具有良好的适应性和推广潜力。"){record_delimiter}
+ ("entity"{tuple_delimiter}"高产栽培技术"{tuple_delimiter}"technology"{tuple_delimiter}"高产栽培技术是指为了提高黄华占水稻产量而采用的一系列农业技术措施。"){record_delimiter}
+ ("entity"{tuple_delimiter}"稳产、优质、增效"{tuple_delimiter}"concept"{tuple_delimiter}"这是铅山县粮食生产的主要目标,强调了粮食生产的稳定、质量和效益。"){record_delimiter}
+ ("entity"{tuple_delimiter}"优质稻推广"{tuple_delimiter}"mission"{tuple_delimiter}"优质稻推广是铅山县粮食生产的一个重要任务,旨在提高稻米的质量和市场竞争力。"){record_delimiter}
+ ("entity"{tuple_delimiter}"杂交水稻技术"{tuple_delimiter}"technology"{tuple_delimiter}"杂交水稻技术是中国领先的世界级农业技术,用于提高水稻的产量和质量。"){record_delimiter}
+ ("entity"{tuple_delimiter}"北涨看长粒香、南涨看黄华占"{tuple_delimiter}"concept"{tuple_delimiter}"这是市场对不同地区优质稻品种的习惯性关注点,北方面对长粒香,南方面对黄华占。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"2016—2017年"{tuple_delimiter}"黄华占在2016—2017年期间在铅山县进行了示范种植,展示了其优良的特性。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"铅山县"{tuple_delimiter}"黄华占在铅山县进行了示范种植,表现出了优良的适应性和产量。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"汪二镇"{tuple_delimiter}"黄华占在汪二镇进行了示范种植,这是其在铅山县示范种植的一部分。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"火田村"{tuple_delimiter}"黄华占在火田村进行了试验示范种植,这是其在汪二镇示范种植的一部分。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"广东省农业科学院水稻研究所"{tuple_delimiter}"黄华占是由广东省农业科学院水稻研究所选育的,该研究所负责其研发工作。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"黄新占/丰华占"{tuple_delimiter}"黄华占的亲本是黄新占和丰华占,这些亲本用于选育黄华占。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"粤、湘、鄂、浙、桂、琼等省"{tuple_delimiter}"黄华占通过了这些省份的审定,表明其在这些地区的适应性和推广潜力。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"高产栽培技术"{tuple_delimiter}"高产栽培技术是为了提高黄华占水稻产量而开发的技术措施。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"铅山县"{tuple_delimiter}"稳产、优质、增效"{tuple_delimiter}"铅山县的粮食生产目标是稳产、优质、增效,这些目标指导了黄华占的示范种植。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"铅山县"{tuple_delimiter}"优质稻推广"{tuple_delimiter}"铅山县实施了优质稻推广计划,黄华占是该计划的一部分。"){record_delimiter}
+ ("relationship"{tuple_delimiter}"杂交水稻技术"{tuple_delimiter}"北涨看长粒香、南涨看黄华占"{tuple_delimiter}"杂交水稻技术的发展使得黄华占等优质稻品种在市场中受到关注。"){record_delimiter}
+ ("content_keywords"{tuple_delimiter}"黄华占, 水稻种植, 高产栽培技术, 优质稻推广, 地区适应性, 市场趋势, 技术影响"){completion_delimiter}
+
+ -真实数据-
+ 实体类型:{entity_types}
+ 文本:{input_text}
+ ################
+ 输出:
+ """
+
+ CONTINUE_EN: str = """MANY entities and relationships were missed in the last extraction. \
+ Add them below using the same format:
+ """
+
+ CONTINUE_ZH: str = """很多实体和关系在上一次的提取中可能被遗漏了。请在下面使用相同的格式添加它们:"""
+
+ IF_LOOP_EN: str = """It appears some entities and relationships may have still been missed. \
+ Answer YES | NO if there are still entities and relationships that need to be added.
+ """
+
+ IF_LOOP_ZH: str = """看起来可能仍然遗漏了一些实体和关系。如果仍有实体和关系需要添加,请回答YES | NO。"""
+
+ KG_EXTRACTION_PROMPT: dict = {
+     "English": {
+         "TEMPLATE": TEMPLATE_EN,
+         "CONTINUE": CONTINUE_EN,
+         "IF_LOOP": IF_LOOP_EN,
+     },
+     "Chinese": {
+         "TEMPLATE": TEMPLATE_ZH,
+         "CONTINUE": CONTINUE_ZH,
+         "IF_LOOP": IF_LOOP_ZH,
+     },
+     "FORMAT": {
+         "tuple_delimiter": "<|>",
+         "record_delimiter": "##",
+         "completion_delimiter": "<|COMPLETE|>",
+         "entity_types": "concept, date, location, keyword, organization, person, event, work, nature, artificial, \
+ science, technology, mission, gene",
+         "language": "English",
+     },
+ }
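A minimal sketch of how the FORMAT block is injected into the extraction template, and how one emitted record can be split back into fields; the sample record mirrors the English example above:

    from graphgen.templates import KG_EXTRACTION_PROMPT

    fmt = KG_EXTRACTION_PROMPT["FORMAT"]
    prompt = KG_EXTRACTION_PROMPT["English"]["TEMPLATE"].format(
        input_text="In the second century of the Christian Era ...",
        **fmt,
    )

    record = '("entity"<|>"Trajan"<|>"person"<|>"Roman emperor known for his virtue.")'
    fields = record.strip("()").split(fmt["tuple_delimiter"])
    # fields -> ['"entity"', '"Trajan"', '"person"', '"Roman emperor known for his virtue."']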