tt-dart committed on
Commit
d834d9d
·
1 Parent(s): e91a58c

update readme

Browse files
Files changed (34) hide show
  1. .gitignore +2 -0
  2. {finetune → NL2HLTLTranslator}/Llama2_13b/data/eval/tf-ltl_eng_test_mid_ascii_gptAuged/data-00000-of-00001.arrow +0 -0
  3. NL2HLTLTranslator/Llama2_13b/data/eval/tf-ltl_eng_test_mid_ascii_gptAuged/dataset_info.json +3 -0
  4. NL2HLTLTranslator/Llama2_13b/data/eval/tf-ltl_eng_test_mid_ascii_gptAuged/state.json +3 -0
  5. {finetune → NL2HLTLTranslator}/Llama2_13b/data/train/tf-ltl_eng_test_mid_ascii_gptAuged/data-00000-of-00001.arrow +0 -0
  6. NL2HLTLTranslator/Llama2_13b/data/train/tf-ltl_eng_test_mid_ascii_gptAuged/dataset_info.json +3 -0
  7. NL2HLTLTranslator/Llama2_13b/data/train/tf-ltl_eng_test_mid_ascii_gptAuged/state.json +3 -0
  8. {finetune → NL2HLTLTranslator}/Llama2_13b/llama_dp2_patch.py +0 -0
  9. {finetune → NL2HLTLTranslator}/Llama2_13b/llama_lora_fintune.py +0 -0
  10. {finetune → NL2HLTLTranslator}/Llama2_13b/llama_lora_fintune_ver2.py +0 -0
  11. {finetune → NL2HLTLTranslator}/Llama2_13b/llama_lora_fintune_ver3_qlora.py +0 -0
  12. {finetune → NL2HLTLTranslator}/Llama2_13b/llama_lora_test.py +1 -1
  13. {finetune → NL2HLTLTranslator}/Llama2_13b/llama_test.py +0 -0
  14. {finetune → NL2HLTLTranslator}/MIT_NL2TL/NL2TL.py +0 -0
  15. {finetune → NL2HLTLTranslator}/T5_XXL/t5_lora_evaluate.py +0 -0
  16. {finetune → NL2HLTLTranslator}/T5_XXL/t5_lora_fintune.py +0 -0
  17. {finetune → NL2HLTLTranslator}/T5_XXL/t5_realtime_evaluate.py +0 -0
  18. {finetune → NL2HLTLTranslator}/__init__.py +0 -0
  19. {finetune → NL2HLTLTranslator}/data_augmentation/GPTbasedAug.py +0 -0
  20. {finetune → NL2HLTLTranslator}/data_augmentation/dataset_creator.py +0 -0
  21. NL2HLTLTranslator/fastapi_server.py +398 -0
  22. {finetune → NL2HLTLTranslator}/mistral7b/finetune.py +2 -2
  23. {finetune → NL2HLTLTranslator}/mistral7b/prediction.py +6 -6
  24. {finetune → NL2HLTLTranslator}/mistral7b/test.py +0 -0
  25. {finetune → NL2HLTLTranslator}/realtime_run.py +0 -0
  26. {finetune → NL2HLTLTranslator}/test.py +0 -0
  27. NL2HLTLTranslator/utils/util.py +449 -0
  28. NL2TL-dataset/collect2/getUniqueLTL.py +2 -2
  29. README.md +13 -1
  30. finetune/Llama2_13b/data/eval/tf-ltl_eng_test_mid_ascii_gptAuged/dataset_info.json +0 -69
  31. finetune/Llama2_13b/data/eval/tf-ltl_eng_test_mid_ascii_gptAuged/state.json +0 -13
  32. finetune/Llama2_13b/data/train/tf-ltl_eng_test_mid_ascii_gptAuged/dataset_info.json +0 -69
  33. finetune/Llama2_13b/data/train/tf-ltl_eng_test_mid_ascii_gptAuged/state.json +0 -13
  34. setup.py +12 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.egg-info
2
+ *.pyc
{finetune → NL2HLTLTranslator}/Llama2_13b/data/eval/tf-ltl_eng_test_mid_ascii_gptAuged/data-00000-of-00001.arrow RENAMED
File without changes
NL2HLTLTranslator/Llama2_13b/data/eval/tf-ltl_eng_test_mid_ascii_gptAuged/dataset_info.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abd1c5e1bcff098b3f091b0b6e7161643fe38547c1b323eba17dd050461845c5
3
+ size 1370
NL2HLTLTranslator/Llama2_13b/data/eval/tf-ltl_eng_test_mid_ascii_gptAuged/state.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6caef2d70554a357d04d62d5e8a5c1f2ead2c9eb245f2688f1ea714b858e5a95
3
+ size 249
{finetune → NL2HLTLTranslator}/Llama2_13b/data/train/tf-ltl_eng_test_mid_ascii_gptAuged/data-00000-of-00001.arrow RENAMED
File without changes
NL2HLTLTranslator/Llama2_13b/data/train/tf-ltl_eng_test_mid_ascii_gptAuged/dataset_info.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abd1c5e1bcff098b3f091b0b6e7161643fe38547c1b323eba17dd050461845c5
3
+ size 1370
NL2HLTLTranslator/Llama2_13b/data/train/tf-ltl_eng_test_mid_ascii_gptAuged/state.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45e111fa90561c16eaea253bbb292557a4b7c17acb1f152096af221bc775ddcb
3
+ size 250
{finetune → NL2HLTLTranslator}/Llama2_13b/llama_dp2_patch.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/Llama2_13b/llama_lora_fintune.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/Llama2_13b/llama_lora_fintune_ver2.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/Llama2_13b/llama_lora_fintune_ver3_qlora.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/Llama2_13b/llama_lora_test.py RENAMED
@@ -22,7 +22,7 @@ os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
22
 
23
  class Llama_NL2TL_translator():
24
  def __init__(self,
25
- output_dir = "/home/icl-mill19/xsj/model_weight",
26
  tuned_model_name="llama2_13b__mid_asciiaug1",
27
  # CUDA_device='0',
28
  quat=True) -> None:
 
22
 
23
  class Llama_NL2TL_translator():
24
  def __init__(self,
25
+ output_dir = "path/to/model_weight",
26
  tuned_model_name="llama2_13b__mid_asciiaug1",
27
  # CUDA_device='0',
28
  quat=True) -> None:
{finetune → NL2HLTLTranslator}/Llama2_13b/llama_test.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/MIT_NL2TL/NL2TL.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/T5_XXL/t5_lora_evaluate.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/T5_XXL/t5_lora_fintune.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/T5_XXL/t5_realtime_evaluate.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/__init__.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/data_augmentation/GPTbasedAug.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/data_augmentation/dataset_creator.py RENAMED
File without changes
NL2HLTLTranslator/fastapi_server.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ modified by xsj
3
+ This script implements an API for the ChatGLM3-6B model,
4
+ formatted similarly to OpenAI's API (https://platform.openai.com/docs/api-reference/chat).
5
+ It's designed to be run as a web server using FastAPI and uvicorn,
6
+ making the ChatGLM3-6B model accessible through OpenAI Client.
7
+
8
+ Key Components and Features:
9
+ - Model and Tokenizer Setup: Configures the model and tokenizer paths and loads them.
10
+ - FastAPI Configuration: Sets up a FastAPI application with CORS middleware for handling cross-origin requests.
11
+ - API Endpoints:
12
+ - "/v1/models": Lists the available models, specifically ChatGLM3-6B.
13
+ - "/v1/chat/completions": Processes chat completion requests with options for streaming and regular responses.
14
+ - "/v1/embeddings": Processes Embedding request of a list of text inputs.
15
+ - Token Limit Caution: In the OpenAI API, 'max_tokens' is equivalent to HuggingFace's 'max_new_tokens', not 'max_length'.
16
+ For instance, setting 'max_tokens' to 8192 for a 6b model would result in an error due to the model's inability to output
17
+ that many tokens after accounting for the history and prompt tokens.
18
+ - Stream Handling and Custom Functions: Manages streaming responses and custom function calls within chat responses.
19
+ - Pydantic Models: Defines structured models for requests and responses, enhancing API documentation and type safety.
20
+ - Main Execution: Initializes the model and tokenizer, and starts the FastAPI app on the designated host and port.
21
+
22
+ Note:
23
+ This script doesn't include the setup for special tokens or multi-GPU support by default.
24
+ Users need to configure their special tokens and can enable multi-GPU support as per the provided instructions.
25
+ Embedding models are only supported on a single GPU.
26
+
27
+ """
28
+
29
+ import os
30
+ import time
31
+ import tiktoken
32
+ import torch
33
+ import uvicorn
34
+
35
+ from fastapi import FastAPI, HTTPException, Response
36
+ from fastapi.middleware.cors import CORSMiddleware
37
+
38
+ from contextlib import asynccontextmanager
39
+ from typing import List, Literal, Optional, Union
40
+ from loguru import logger
41
+ from pydantic import BaseModel, Field
42
+ from transformers import AutoTokenizer, AutoModel
43
+ # from utils import process_response, generate_chatglm3, generate_stream_chatglm3
44
+ from sentence_transformers import SentenceTransformer
45
+
46
+ from sse_starlette.sse import EventSourceResponse
47
+
48
+
49
+ # from NL2HLTLtaskPlanner.finetune.Llama2_13b.llama_lora_test import Llama_NL2TL_translator as NL2TL_translator
50
+ from NL2HLTLTranslator.mistral7b.prediction import Mistral_NL2TL_translator as NL2TL_translator
51
+ # Set up limit request time
52
+ EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
53
+ # set LLM path
54
+ output_dir = os.path.join(os.path.dirname(__file__),"../")
55
+ tuned_model_name="mistral7b_quat8"
56
+
57
+
58
+ @asynccontextmanager
59
+ async def lifespan(app: FastAPI):
60
+ yield
61
+ if torch.cuda.is_available():
62
+ torch.cuda.empty_cache()
63
+ torch.cuda.ipc_collect()
64
+
65
+
66
+ app = FastAPI(lifespan=lifespan)
67
+
68
+ app.add_middleware(
69
+ CORSMiddleware,
70
+ allow_origins=["*"],
71
+ allow_credentials=True,
72
+ allow_methods=["*"],
73
+ allow_headers=["*"],
74
+ )
75
+
76
+
77
+ class ModelCard(BaseModel):
78
+ id: str
79
+ object: str = "model"
80
+ created: int = Field(default_factory=lambda: int(time.time()))
81
+ owned_by: str = "owner"
82
+ root: Optional[str] = None
83
+ parent: Optional[str] = None
84
+ permission: Optional[list] = None
85
+
86
+
87
+ class ModelList(BaseModel):
88
+ object: str = "list"
89
+ data: List[ModelCard] = []
90
+
91
+
92
+ class FunctionCallResponse(BaseModel):
93
+ name: Optional[str] = None
94
+ arguments: Optional[str] = None
95
+
96
+
97
+ class ChatMessage(BaseModel):
98
+ role: Literal["user", "assistant", "system", "function"]
99
+ content: str = None
100
+ name: Optional[str] = None
101
+ function_call: Optional[FunctionCallResponse] = None
102
+
103
+
104
+ class DeltaMessage(BaseModel):
105
+ role: Optional[Literal["user", "assistant", "system"]] = None
106
+ content: Optional[str] = None
107
+ function_call: Optional[FunctionCallResponse] = None
108
+
109
+
110
+ ## for Embedding
111
+ class EmbeddingRequest(BaseModel):
112
+ input: List[str]
113
+ model: str
114
+
115
+
116
+ class CompletionUsage(BaseModel):
117
+ prompt_tokens: int
118
+ completion_tokens: int
119
+ total_tokens: int
120
+
121
+
122
+ class EmbeddingResponse(BaseModel):
123
+ data: list
124
+ model: str
125
+ object: str
126
+ usage: CompletionUsage
127
+
128
+
129
+ # for ChatCompletionRequest
130
+
131
+ class UsageInfo(BaseModel):
132
+ prompt_tokens: int = 0
133
+ total_tokens: int = 0
134
+ completion_tokens: Optional[int] = 0
135
+
136
+
137
+ class ChatCompletionRequest(BaseModel):
138
+ model: str
139
+ messages: List[ChatMessage]
140
+ temperature: Optional[float] = 0.8
141
+ top_p: Optional[float] = 0.8
142
+ max_tokens: Optional[int] = None
143
+ stream: Optional[bool] = False
144
+ tools: Optional[Union[dict, List[dict]]] = None
145
+ repetition_penalty: Optional[float] = 1.1
146
+
147
+
148
+ class ChatCompletionResponseChoice(BaseModel):
149
+ index: int
150
+ message: ChatMessage
151
+ finish_reason: Literal["stop", "length", "function_call"]
152
+
153
+
154
+ class ChatCompletionResponseStreamChoice(BaseModel):
155
+ delta: DeltaMessage
156
+ finish_reason: Optional[Literal["stop", "length", "function_call"]]
157
+ index: int
158
+
159
+
160
+ class ChatCompletionResponse(BaseModel):
161
+ model: str
162
+ id: str
163
+ object: Literal["chat.completion", "chat.completion.chunk"]
164
+ choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
165
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
166
+ usage: Optional[UsageInfo] = None
167
+
168
+
169
+ @app.get("/health")
170
+ async def health() -> Response:
171
+ """Health check."""
172
+ return Response(status_code=200)
173
+
174
+
175
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    """
    Embed every string in ``request.input`` and return an OpenAI-style payload.

    :param request: embedding request carrying the input texts and model name.
    :return: dict with one ``embedding`` entry per input (in input order), the
        echoed model name and token usage counted with the ``cl100k_base``
        tiktoken encoding.
    :raises HTTPException: 503 when no module-level ``embedding_model`` has
        been loaded.
    """
    # The embedding model is an optional module-level global; looking it up
    # through globals() turns "never loaded" into a clean 503 instead of a
    # NameError surfacing as an opaque 500.
    encoder = globals().get("embedding_model")
    if encoder is None:
        raise HTTPException(status_code=503, detail="Embedding model is not loaded")

    # Fetch the encoding once instead of once per input string.
    encoding = tiktoken.get_encoding('cl100k_base')

    def num_tokens_from_string(string: str) -> int:
        """
        Returns the number of tokens in a text string.
        use cl100k_base tokenizer
        """
        return len(encoding.encode(string))

    embeddings = [encoder.encode(text).tolist() for text in request.input]

    # An embedding call produces no completion tokens, so prompt_tokens and
    # total_tokens must agree; both use the same tokenizer-based count.
    token_count = sum(num_tokens_from_string(text) for text in request.input)

    response = {
        "data": [
            {
                "object": "embedding",
                "embedding": embedding,
                "index": index
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": CompletionUsage(
            prompt_tokens=token_count,
            completion_tokens=0,
            total_tokens=token_count,
        )
    }
    return response
207
+
208
+
209
+ @app.get("/v1/models", response_model=ModelList)
210
+ async def list_models():
211
+ model_card = ModelCard(
212
+ id="chatglm3-6b"
213
+ )
214
+ return ModelList(
215
+ data=[model_card]
216
+ )
217
+
218
+ count=0
219
+
220
+ @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
221
+ async def create_chat_completion(request: ChatCompletionRequest):
222
+ global model, tokenizer, LLM
223
+
224
+ if len(request.messages) < 1 or request.messages[-1].role == "assistant":
225
+ raise HTTPException(status_code=400, detail="Invalid request")
226
+
227
+ gen_params = dict(
228
+ messages=request.messages,
229
+ temperature=request.temperature,
230
+ top_p=request.top_p,
231
+ max_tokens=request.max_tokens or 1024,
232
+ echo=False,
233
+ stream=request.stream,
234
+ repetition_penalty=request.repetition_penalty,
235
+ tools=request.tools,
236
+ )
237
+ logger.debug(f"==== request ====\n{gen_params}")
238
+
239
+ # if request.stream:
240
+
241
+ # # Use stream mode to read the first few characters; if the output is not a function call, stream it directly
242
+ # predict_stream_generator = predict_stream(request.model, gen_params)
243
+ # output = next(predict_stream_generator)
244
+ # if not contains_custom_function(output):
245
+ # return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
246
+
247
+ # # Obtain the result directly at one time and determine whether tools needs to be called.
248
+ # logger.debug(f"First result output:\n{output}")
249
+
250
+ # function_call = None
251
+ # if output and request.tools:
252
+ # try:
253
+ # function_call = process_response(output, use_tool=True)
254
+ # except:
255
+ # logger.warning("Failed to parse tool call")
256
+
257
+ # # CallFunction
258
+ # if isinstance(function_call, dict):
259
+ # function_call = FunctionCallResponse(**function_call)
260
+
261
+ # """
262
+ # In this demo, we did not register any tools.
263
+ # You can use the tools that have been implemented in our `tools_using_demo` and implement your own streaming tool implementation here.
264
+ # Similar to the following method:
265
+ # function_args = json.loads(function_call.arguments)
266
+ # tool_response = dispatch_tool(tool_name: str, tool_params: dict)
267
+ # """
268
+ # tool_response = ""
269
+
270
+ # if not gen_params.get("messages"):
271
+ # gen_params["messages"] = []
272
+
273
+ # gen_params["messages"].append(ChatMessage(
274
+ # role="assistant",
275
+ # content=output,
276
+ # ))
277
+ # gen_params["messages"].append(ChatMessage(
278
+ # role="function",
279
+ # name=function_call.name,
280
+ # content=tool_response,
281
+ # ))
282
+
283
+ # # Streaming output of results after function calls
284
+ # generate = predict(request.model, gen_params)
285
+ # return EventSourceResponse(generate, media_type="text/event-stream")
286
+
287
+ # else:
288
+ # # Handled to avoid exceptions in the above parsing function process.
289
+ # generate = parse_output_text(request.model, output)
290
+ # return EventSourceResponse(generate, media_type="text/event-stream")
291
+
292
+ # Here is the handling of stream = False
293
+ # print("gen_params['messages'][0].content",gen_params['messages'][0].content)
294
+ response=LLM.translate(gen_params['messages'][0].content)
295
+ # print('response',response)
296
+ # return response
297
+ # response = generate_chatglm3(model, tokenizer, gen_params)
298
+
299
+ # # Remove the first newline character
300
+ # if response["text"].startswith("\n"):
301
+ # response["text"] = response["text"][1:]
302
+ # response["text"] = response["text"].strip()
303
+
304
+ usage = UsageInfo()
305
+ # function_call, finish_reason = None, "stop"
306
+ # if request.tools:
307
+ # try:
308
+ # function_call = process_response(response["text"], use_tool=True)
309
+ # except:
310
+ # logger.warning("Failed to parse tool call, maybe the response is not a tool call or have been answered.")
311
+
312
+ # if isinstance(function_call, dict):
313
+ # finish_reason = "function_call"
314
+ # function_call = FunctionCallResponse(**function_call)
315
+ function_call = None
316
+
317
+ message = ChatMessage(
318
+ role="assistant",
319
+ content=response,
320
+ function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
321
+ )
322
+
323
+ logger.debug(f"==== message ====\n{message}")
324
+
325
+ choice_data = ChatCompletionResponseChoice(
326
+ index=0,
327
+ message=message,
328
+ finish_reason='stop',
329
+ )
330
+ # task_usage = UsageInfo.model_validate(response["usage"])
331
+ # for usage_key, usage_value in task_usage.model_dump().items():
332
+ # setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
333
+ # count+=1
334
+ return ChatCompletionResponse(
335
+ model=request.model,
336
+ id="", # for open_source model, id is empty
337
+ choices=[choice_data],
338
+ object="chat.completion",
339
+ usage=usage
340
+ )
341
+
342
+
343
+ async def parse_output_text(model_id: str, value: str):
344
+ """
345
+ Directly output the text content of value
346
+
347
+ :param model_id:
348
+ :param value:
349
+ :return:
350
+ """
351
+ choice_data = ChatCompletionResponseStreamChoice(
352
+ index=0,
353
+ delta=DeltaMessage(role="assistant", content=value),
354
+ finish_reason=None
355
+ )
356
+ chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
357
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
358
+
359
+ choice_data = ChatCompletionResponseStreamChoice(
360
+ index=0,
361
+ delta=DeltaMessage(),
362
+ finish_reason="stop"
363
+ )
364
+ chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
365
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
366
+ yield '[DONE]'
367
+
368
+
369
def contains_custom_function(value: str) -> bool:
    """
    Determine whether 'function_call' according to a special function prefix.

    For example, the functions defined in "tools_using_demo/tool_register.py" are all "get_xxx" and start with "get_"

    [Note] This is not a rigorous judgment method, only for reference.

    :param value: text to inspect; may be empty or None.
    :return: True when ``value`` is non-empty and contains the substring 'get_',
        otherwise False (always a real bool, never '' or None).
    """
    # bool() keeps the short-circuit for empty/None input while honouring the
    # declared return type.
    return bool(value) and 'get_' in value
381
+
382
+ def run(output_dir = "path/to/model_weight", tuned_model_name="llama2_13b__mid_asciiaug1",CUDA_device='0',quat=True):
383
+ global LLM
384
+ LLM=NL2TL_translator(output_dir=output_dir,tuned_model_name= tuned_model_name,quat=quat)
385
+
386
+ tokenizer = LLM.tokenizer
387
+ model = LLM.model
388
+
389
+ # load Embedding
390
+ # embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")
391
+ uvicorn.run(app, host='0.0.0.0', port=8001, workers=1)
392
+ if __name__ == "__main__":
393
+ # Load LLM
394
+ # on alinware mill19
395
+ # run()
396
+ # on icl-superman
397
+ run(output_dir=output_dir,tuned_model_name=tuned_model_name)
398
+ # on zju server
{finetune → NL2HLTLTranslator}/mistral7b/finetune.py RENAMED
@@ -32,8 +32,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
32
  # dataset = load_dataset("samsum")
33
  device='cuda'
34
  np.random.seed(42)
35
- output_dir = "/home/user/xsj/model_weight/"
36
- datapath='/home/user/xsj/NL2TL-dataset/collect2'
37
  exp_name="_mid_ascii_0327_eos_2"
38
  explainer_files=['LTLexplain_0.json','LTLexplain_1.json','LTLexplain_2.json','LTLexplain_3.json']
39
  explainer_dic={}
 
32
  # dataset = load_dataset("samsum")
33
  device='cuda'
34
  np.random.seed(42)
35
+ output_dir = os.path.join(os.path.dirname(__file__),'../')
36
+ datapath=os.path.join(os.path.dirname(__file__),'../NL2TL-dataset/collect2')
37
  exp_name="_mid_ascii_0327_eos_2"
38
  explainer_files=['LTLexplain_0.json','LTLexplain_1.json','LTLexplain_2.json','LTLexplain_3.json']
39
  explainer_dic={}
{finetune → NL2HLTLTranslator}/mistral7b/prediction.py RENAMED
@@ -9,8 +9,8 @@ from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
9
  # from accelerate import infer_auto_device_map,init_empty_weights
10
 
11
  # sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
12
- from NL2HLTLtaskPlanner.utils import Task2Preplacer
13
- from NL2HLTLtaskPlanner.utils import LTLChecker
14
  import re
15
  from datasets import concatenate_datasets
16
  import numpy as np
@@ -22,8 +22,8 @@ os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
22
 
23
  class Mistral_NL2TL_translator():
24
  def __init__(self,
25
- output_dir = "/home/user/xsj/model_weight",
26
- tuned_model_name="mistral7b_mid_ascii_0327_eos_2aug1_quat8",
27
  # CUDA_device='0',
28
  quat=True,
29
  replacer=Task2Preplacer) -> None:
@@ -237,7 +237,7 @@ if __name__=="__main__":
237
 
238
  # Metric
239
  metric = evaluate.load("rouge")
240
- datapath='/home/user/xsj/NL2TL-dataset/collect2'
241
  tokenized_dataset = load_dataset("json", data_files={"train":os.path.join(datapath,"ltl_eng_train_mid_ascii_gptAuged.jsonl"),"test":os.path.join(datapath,"ltl_eng_test_mid_ascii_gptAuged.jsonl")})
242
  print(tokenized_dataset)
243
  # run predictions
@@ -276,7 +276,7 @@ if __name__=="__main__":
276
  eval_output=np.array([input_sentence,predictions,references]).T
277
  import pandas as pd
278
  eval_output=pd.DataFrame(eval_output)
279
- pd.DataFrame.to_csv(eval_output,"/home/user/xsj/model_weight/mistral7b_mid_ascii_0327_eos_2aug1_quat8"+'/output')
280
  # out llama
281
  # Rogue1: 98.363321%
282
  # rouge2: 95.987820%
 
9
  # from accelerate import infer_auto_device_map,init_empty_weights
10
 
11
  # sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
12
+ from NL2HLTLTranslator.utils.util import Task2Preplacer
13
+ from NL2HLTLTranslator.utils.util import LTLChecker
14
  import re
15
  from datasets import concatenate_datasets
16
  import numpy as np
 
22
 
23
  class Mistral_NL2TL_translator():
24
  def __init__(self,
25
+ output_dir = os.path.join(os.path.dirname(__file__),'../../'),
26
+ tuned_model_name="mistral7b_quat8",
27
  # CUDA_device='0',
28
  quat=True,
29
  replacer=Task2Preplacer) -> None:
 
237
 
238
  # Metric
239
  metric = evaluate.load("rouge")
240
+ datapath='path/to/NL2TL-dataset/collect2'
241
  tokenized_dataset = load_dataset("json", data_files={"train":os.path.join(datapath,"ltl_eng_train_mid_ascii_gptAuged.jsonl"),"test":os.path.join(datapath,"ltl_eng_test_mid_ascii_gptAuged.jsonl")})
242
  print(tokenized_dataset)
243
  # run predictions
 
276
  eval_output=np.array([input_sentence,predictions,references]).T
277
  import pandas as pd
278
  eval_output=pd.DataFrame(eval_output)
279
+ pd.DataFrame.to_csv(eval_output,"path/to/model_weight/mistral7b_mid_ascii_0327_eos_2aug1_quat8"+'/output')
280
  # out llama
281
  # Rogue1: 98.363321%
282
  # rouge2: 95.987820%
{finetune → NL2HLTLTranslator}/mistral7b/test.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/realtime_run.py RENAMED
File without changes
{finetune → NL2HLTLTranslator}/test.py RENAMED
File without changes
NL2HLTLTranslator/utils/util.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import sys,os
4
+ # sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
5
+
6
+
7
def _repair_json_snippet(snippet):
    """Strip layout whitespace from one snippet and balance its curly braces."""
    # NOTE(review): the text this file was recovered from had its whitespace
    # collapsed; the literal below is assumed to be a four-space indent (the
    # original comment warns that removing single spaces would be wrong).
    # Confirm against the repository.
    snippet = snippet.replace("    ", "").replace("\t", "")
    # Drop line breaks first so indentation left at either end can be
    # stripped; strip() is also safe on an empty capture.
    snippet = snippet.replace("\n", "").strip(" ")

    opens = snippet.count('{')
    closes = snippet.count('}')
    if closes < opens:
        # Truncated output: close every brace that was left open.
        snippet += '}' * (opens - closes)
    else:
        # Too many closing braces: drop the surplus from the end.
        surplus = closes - opens
        while surplus > 0 and snippet.endswith('}'):
            snippet = snippet[:-1]
            surplus -= 1

    if not snippet.startswith('{'):
        # Bare "key": value content: wrap it into an object.
        snippet = '{' + snippet + '}'
    return snippet


def splitJSONfromTXT(txt:str=""):
    """
    Extract JSON-looking snippets from free text and normalise them.

    The text is searched, in order, for a ```json fence, a ```JSON fence, any
    ``` fence and finally the outermost ``{...}`` span; the first pattern that
    matches supplies the snippets. Each snippet then has indentation, tabs and
    line breaks removed and its curly braces balanced.

    :param txt: text to search.
    :return: list of cleaned snippet strings (empty when nothing matched).
        The snippets are not parsed; callers still have to ``json.loads`` them.
    """
    patterns = (
        r"```json([\s\S]*)```",
        r"```JSON([\s\S]*)```",
        r"```([\s\S]*)```",
        r"({[\s\S]*})",
    )
    jsons = []
    for pattern in patterns:
        jsons = re.findall(pattern, txt)
        if jsons:
            break
    print('find {} JSON para\n'.format(len(jsons)))

    # Each snippet is repaired independently, so one snippet's brace
    # bookkeeping can never touch another entry of the list.
    return [_repair_json_snippet(snippet) for snippet in jsons]
71
def readMultipleLinesTillException(prompt:str="",log=False,exp_PATH=""):
    """
    Read lines from standard input until reading fails, and return them joined.

    The first line is read with ``prompt`` shown; further lines are read
    without a prompt until ``input()`` raises (typically ``EOFError`` at end of
    input or ``KeyboardInterrupt``). The lines are concatenated without
    separators, prefixed by a single space.

    :param prompt: prompt shown for the first line.
    :param log: when true, append the prompt and the collected text to
        ``exp_PATH + 'log.txt'``.
    :param exp_PATH: prefix prepended verbatim to ``'log.txt'`` (include a
        trailing path separator if it is a directory).
    :return: the collected text.
    """
    print("--"*15)
    lines=[" "]
    try:
        lines.append(input(prompt))
        while True:
            lines.append(input())
    except (Exception, KeyboardInterrupt):
        # Any read failure or Ctrl-C ends the input; SystemExit is deliberately
        # no longer swallowed so the process can still be shut down.
        pass
    ret ="".join(lines)
    if log:
        # Explicit UTF-8 so non-ASCII prompts/answers do not depend on the
        # platform's default encoding.
        with open(exp_PATH+'log.txt','a',encoding='utf-8') as f:
            f.write("\nINPUT:\n")
            f.write(prompt)
            f.write("\nOUTPUT:\n")
            f.write(ret)
    return ret
89
+ # class GPTinterface
90
+ # def GPTinterface(prompts:list,log=False,exp_PATH=""):
91
+
92
+ # pass
93
+
94
+ # # def
95
+ # class GPTTranslater():
96
+ # def __init__(self) -> None:
97
+ # pass
98
+ # def translate(self,prompt:str="",log=False,exp_PATH=""):
99
+ # return GPTinterface(prompt=prompt,log=log,exp_PATH=exp_PATH)
100
+
101
+
102
+ import re
103
class Task2Preplacer():
    """
    Swap ``Task_x.y`` style identifiers for short ``Pnn`` placeholders and back.

    ``reTask2P`` builds a fresh mapping on every call; ``reP2Task`` uses the
    mapping of the most recent ``reTask2P`` call to restore the task names.
    """

    def __init__(self,input:str=""):
        # `input` is accepted for interface compatibility but is not used.
        self.Task2PDict={}   # task name  -> placeholder
        self.P2TaskDict={}   # placeholder -> task name
        self.count=0         # number of placeholders handed out so far

    def mapping1(self,name):
        """
        ``re.sub`` callback: return the placeholder for one matched task name.

        Trailing characters after the last digit (e.g. a sentence-ending '.')
        are not part of the task name and are dropped from the output together
        with it. A match that contains no digit at all is returned unchanged.
        """
        matched=name[0]
        task=matched
        while task and not ('0'<=task[-1]<='9'):
            task=task[:-1]
        if not task:
            # e.g. a bare "Task_": nothing identifies a task, so leave the text
            # alone instead of indexing into an empty string.
            return matched
        if task not in self.Task2PDict:
            self.count+=1
            placeholder="P{:0>2d}".format(self.count)
            self.Task2PDict[task]=placeholder
            self.P2TaskDict[placeholder]=task
        return self.Task2PDict[task]

    def mapping2(self,name):
        """
        ``re.sub`` callback: return the task name for one matched placeholder.

        A placeholder that was never handed out is left as it is rather than
        raising ``KeyError``.
        """
        placeholder=name[0]
        return "{}".format(self.P2TaskDict.get(placeholder,placeholder))

    def reTask2P(self,Taskinput:str=""):
        """Reset the mapping and replace every task name in ``Taskinput``."""
        self.Task2PDict={}
        self.P2TaskDict={}
        self.count=0
        self.Poutput=re.sub(r"Task[_0-9\.]+",self.mapping1,Taskinput)
        print('self.Poutput',self.Poutput,'\n')
        return self.Poutput

    def reP2Task(self,Pinput:str=""):
        """Replace every two-digit ``Pnn`` placeholder in ``Pinput`` with its task name."""
        print("self.P2TaskDict",self.P2TaskDict,'\n')
        print("pinput",Pinput,'\n')

        self.Taskoutput=re.sub(r"P[0-9]{2}",self.mapping2,Pinput)
        return self.Taskoutput
138
+
139
class Func2Preplacer():
    """
    Swap function-call expressions for ``Pnn`` placeholders and back.

    ``functionlist`` is a sequence whose items expose the function signature
    as their first element; the last two characters of that string are cut off
    to obtain the bare name (presumably a trailing "()" - verify against the
    caller).
    """

    def __init__(self,functionlist,input:str=""):
        # `input` is accepted for interface compatibility but is not used.
        self.functionlist=functionlist
        self.Func2Tasklist=[]
        # Both methods read and extend `Task2Plist`; it was never created,
        # which made the first call fail with AttributeError. It shares the
        # list object with the pre-existing attribute name.
        self.Task2Plist=self.Func2Tasklist

    def reFunc2P(self,Taskinput:str=""):
        """
        Replace every matching call in ``Taskinput`` with a placeholder.

        Placeholders are ``P`` plus the zero-based, two-digit position of the
        call text in ``Task2Plist``; the list is kept across calls.
        """
        def mapping(name):
            name=name[0]
            if name not in self.Task2Plist:
                self.Task2Plist+=[name]
            return "P{:0>2d}".format(self.Task2Plist.index(name))

        for func in self.functionlist:
            # NOTE(review): the pattern requires a ']' right before the closing
            # parenthesis, so only calls like name([...]) match. Confirm that
            # this is intended.
            Taskinput=re.sub(func[0][:-2]+r"\([\S]+]\)",mapping,Taskinput)
        self.Poutput=Taskinput
        return self.Poutput

    def reP2Func(self,Pinput:str=""):
        """Replace every two-digit ``Pnn`` placeholder in ``Pinput`` with the stored call text."""
        def mapping(name):
            name=name[0]
            return "{}".format(self.Task2Plist[int(name[1:])])

        self.Taskoutput=re.sub(r"P[0-9]{2}",mapping,Pinput)
        return self.Taskoutput
165
+
166
class getFunc2HierarchicalAP():
    """
    Number function calls (AP) and task names (CP) and rewrite text to use them.

    ``HierarchicalEleList`` is a two-way table: it maps every registered text
    to its integer id and every id back to the text. AP ids start at
    ``AP_start + 1`` and are rendered as ``p`` + at least two digits; CP ids
    start at ``CP_start + 1`` and are rendered as ``p`` + at least three digits.
    """

    def __init__(self,functionlist,input:str=""):
        # `input` is accepted for interface compatibility but is not used.
        self.functionlist=functionlist
        self.HierarchicalEleList=dict()
        self.CP_count=0
        self.CP_start=100
        self.AP_count=0
        self.AP_start=10

    def AsciiLTL2FormalLTL(self,LTLInput:str,task:str,LTLtype:str="CP"):
        """
        Rewrite ``LTLInput`` according to the kind of element it contains.

        :param LTLInput: formula text.
        :param task: task description; only used for ``LTLtype`` "SP".
        :param LTLtype: "CP"/"cp" replaces task names, "AP"/"ap" replaces
            function calls, "SP"/"sp" registers the task names found in
            ``task`` and returns ``LTLInput`` unchanged.
        :return: the rewritten formula, or ``None`` for any other ``LTLtype``.
        """
        if LTLtype in ["CP","cp"]:
            return self.reTask2HierarchicalLTL(LTLInput)
        if LTLtype in ["ap","AP"]:
            return self.reFunc2HierarchicalLTL(LTLInput)
        if LTLtype in ["sp","SP"]:
            self.reTask2HierarchicalLTL(task)
            return LTLInput
        return None

    def mapAP2Function(self,AP:int):
        """Return the text registered under id ``AP``, or ``False`` when unknown."""
        if AP in self.HierarchicalEleList:
            return self.HierarchicalEleList[AP]
        return False

    def mapTask2CP(self,TaskName:str):
        """Return the ``pNNN`` label of a registered task name (``KeyError`` if unknown)."""
        return "p{:0>3d}".format(self.HierarchicalEleList[TaskName])

    def mapping1(self,name):
        """``re.sub`` callback: register one matched function call and return its AP label."""
        name=name[0]
        print('mapping',name)
        if name not in self.HierarchicalEleList:
            self.AP_count+=1
            ap_id=self.AP_count+self.AP_start
            self.HierarchicalEleList[ap_id]=name
            self.HierarchicalEleList[name]=ap_id
        return "p{:0>2d}".format(self.HierarchicalEleList[name])

    def reFunc2HierarchicalLTL(self,Taskinput:str=""):
        """Replace every ``name(args)`` call of the known functions with its AP label."""
        for func in self.functionlist:
            # func[0] minus its last two characters is used as the call name
            # (presumably stripping a trailing "()" - verify against the caller).
            Taskinput=re.sub(func[0][:-2]+r"\([^\)]+\)",self.mapping1,Taskinput)
        self.Poutput=Taskinput
        return self.Poutput

    def mapping2(self,name):
        """
        ``re.sub`` callback: register one matched task name and return its CP label.

        Characters after the last digit are dropped with the match; a match
        without any digit is returned unchanged.
        """
        matched=name[0]
        task=matched
        while task and not ('0'<=task[-1]<='9'):
            task=task[:-1]
        if not task:
            # e.g. "Task_.": no task number present, leave the text untouched
            # instead of indexing into an empty string.
            return matched
        print('mapping2',task)

        if task not in self.HierarchicalEleList:
            self.CP_count+=1
            cp_id=self.CP_count+self.CP_start
            self.HierarchicalEleList[cp_id]=task
            self.HierarchicalEleList[task]=cp_id
        return "p{:0>3d}".format(self.HierarchicalEleList[task])

    def reTask2HierarchicalLTL(self,Taskinput:str=""):
        """Replace every ``Task_<digits/dots>`` name in ``Taskinput`` with its CP label."""
        self.Poutput=re.sub(r"Task_[0-9\.]+",self.mapping2,Taskinput)
        return self.Poutput
237
+ # def reP2Task(self,Pinput:str=""):
238
+ # print("self.Task2Plist",self.Task2Plist,'\n')
239
+ # print("pinput",Pinput,'\n')
240
+ # def mapping(name):
241
+ # return "{}".format(self.Task2Plist[int(name[1:])])
242
+ # pass
243
+ # self.Taskoutput=re.sub("P[0-9]{2}",mapping,Pinput)
244
+ # return self.Taskoutput
245
+
246
+ # Task_1.1 and Task_1.2 can occur independently and either may be executed without affecting the other."
247
class FuncParamExtractor():
    """Pull the function name and the argument strings out of a call string."""

    def __init__(self, functionDefine) -> None:
        """Store the known function definitions.

        Args:
            functionDefine: sequence of entries whose first element is a
                function signature ending in ``()``.
        """
        self.functionlist = functionDefine
        # One argument = the run of characters after "(" or "," up to the
        # next "," or ")". Surrounding whitespace is kept as-is.
        self.ParamPattern = re.compile(r"[\(,]([^,\)]+)")

    def extractParam(self, inputFunc: str = ''):
        """Return the list of non-empty argument strings found in *inputFunc*."""
        return self.ParamPattern.findall(inputFunc)

    def extractFunc(self, inputFunc: str = ''):
        """Return the name of the first known function occurring in *inputFunc*.

        Names are taken from ``functionlist`` in order, with the trailing
        ``()`` removed, and are matched literally. Returns ``None`` when the
        list is empty or no name occurs in *inputFunc* (earlier versions
        crashed with ``UnboundLocalError``/``AttributeError`` in that case).
        """
        for func in self.functionlist:
            name = func[0][:-2]
            if name in inputFunc:
                return name
        return None
263
class LTLChecker():
    """Light-weight sanity checks for generated LTL strings."""

    def __init__(self, APpattern=re.compile("(P[0-9]{2})")) -> None:
        """Remember the compiled pattern used to find atomic propositions."""
        self.APpattern = APpattern

    def AP_CorrCheck(self, natural: str = "", ltl: str = ""):
        """Return True when both strings mention exactly the same set of APs."""
        seen_in_text = set(self.APpattern.findall(natural))
        seen_in_formula = set(self.APpattern.findall(ltl))
        return seen_in_text == seen_in_formula

    def right_barkets_remover(self, ltl: str):
        """Drop surplus ``)`` from the right end of *ltl*.

        Closing brackets are removed one at a time (re-stripping whitespace
        each time) while there are more ``)`` than ``(`` and the string still
        ends with ``)``.
        """
        trimmed = ltl.strip()
        while trimmed.count("(") < trimmed.count(")") and trimmed.endswith(")"):
            trimmed = trimmed[:-1].strip()
        return trimmed

    def brackets_Check(self, ltl: str):
        """Return True when *ltl* has as many ``(`` as ``)``."""
        opening = ltl.count("(")
        return opening == ltl.count(")")
282
+
283
def reAsciiLTL2EngLTL(AsciiInput: str = ''):
    """Spell out single-letter LTL operators as English words.

    Mapping: A->And, O->Or, I->Imply, N->Not, E->Equally, F->Finally,
    G->Globally, U->Until, X->Next. Every occurrence of these capital
    letters is replaced, so the input must not use them for anything else.
    All other characters are copied unchanged.

    Non-string input is returned as-is, matching ``reEngLTL2FormalLTL``.
    """
    if not isinstance(AsciiInput, str):
        return AsciiInput
    # One pass is enough: no replacement word contains an operator letter
    # other than its own, so this equals the former chain of substitutions.
    operator_words = str.maketrans({
        'A': 'And',
        'O': 'Or',
        'I': 'Imply',
        'N': 'Not',
        'E': 'Equally',
        'F': 'Finally',
        'G': 'Globally',
        'U': 'Until',
        'X': 'Next',
    })
    return AsciiInput.translate(operator_words)
299
def reEngLTL2FormalLTL(ENGInput: str = ''):
    """Turn English LTL operator words into their symbolic form.

    And->&&, Or->||, Imply->->, Not->!, Equally-><=>, Finally-><>,
    Globally->[], Until->U, Next->X. Anything that is not a string is
    handed back untouched.
    """
    if not isinstance(ENGInput, str):
        return ENGInput
    word_to_symbol = (
        ('And', '&&'),
        ('Or', '||'),
        ('Imply', '->'),
        ('Not', '!'),
        ('Equally', '<=>'),
        ('Finally', '<>'),
        ('Globally', '[]'),
        ('Until', 'U'),
        ('Next', 'X'),
    )
    formal = ENGInput
    for word, symbol in word_to_symbol:
        formal = formal.replace(word, symbol)
    return formal
312
def reAsciiLTL2FormalLTL(AsciiInput: str = ''):
    """Turn single-letter LTL operators into their symbolic form.

    Mapping: A->&&, O->||, I->->, N->!, E-><=>, F-><>, G->[]. ``U`` (until)
    and ``X`` (next) are already in their final form and stay unchanged.
    Every occurrence of the mapped capital letters is replaced, so the input
    must not use them for anything else.

    Non-string input is returned as-is, matching ``reEngLTL2FormalLTL``.
    """
    if not isinstance(AsciiInput, str):
        return AsciiInput
    # The symbols contain no operator letters, so one pass is equivalent to
    # the former chain of substitutions.
    operator_symbols = str.maketrans({
        'A': '&&',
        'O': '||',
        'I': '->',
        'N': '!',
        'E': '<=>',
        'F': '<>',
        'G': '[]',
    })
    return AsciiInput.translate(operator_symbols)
323
+ if __name__=="__main__":
324
+ # print(getFunc2HierarchicalAP([
325
+ # ["Move_raw_ingredient_to_utensile()","name of ingredient", "name of utensil"],
326
+ # ["Move_utensil_to_certain_area()","name of utensil", "area"],
327
+ # ["Move_processed_ingredient_to_utensile()","name of ingredient", "name of utensil"],
328
+ # ["Processing_ingredient()","name of ingredient", "blue_knife/yellow_knife/hand"],
329
+ # ]).reFunc2HierarchicalLTL('Move_raw_ingredient_to_utensile(asd,asd)'))
330
+ # # print(FuncParamExtractor().extractFunc('Processing_ingredient(123,243,dfsa,3,)'))
331
+ # exit()
332
+ print(splitJSONfromTXT("""```json
333
+ {
334
+ "Task_1":{
335
+ "task_id":"Task_1",
336
+ "task_instruction":"Prepare dishes by arranging fruit and vegetables and preparing eggs and meats.",
337
+ "task_relied_description":"We have several dishes to create, and there's a specific order for preparation.",
338
+ "sibling_nodes_condition":"",
339
+ "subtasks_of_this_node":["Task_1.1", "Task_1.2"]
340
+ },
341
+ "Task_1.1":{
342
+ "task_id":"Task_1.1",
343
+ "task_instruction":"Arrange fruit on the yellow plate then place vegetables on the blue plate.",
344
+ "task_relied_description":"Start by arranging the fruit on the yellow plate, followed by the vegetables on the blue plate.",
345
+ "sibling_nodes_condition":"",
346
+ "subtasks_of_this_node":["Task_1.1.1", "Task_1.1.2"]
347
+ },
348
+ "Task_1.1.1":{
349
+ "task_id":"Task_1.1.1",
350
+ "task_instruction":"Put half of the sliced tomato and the sliced watermelon on the yellow plate.",
351
+ "task_relied_description":"We have a tomato and a watermelon to serve as fruit.",
352
+ "sibling_nodes_condition":"",
353
+ "subtasks_of_this_node":["Task_1.1.1.1", "Task_1.1.1.2"]
354
+ },
355
+ "Task_1.1.1.1":{
356
+ "task_id":"Task_1.1.1.1",
357
+ "task_instruction":"Place tomato and watermelon on the yellow cutting board and slice each in half.",
358
+ "task_relied_description":"These can go on the yellow cutting board in any sequence. Slice each in half.",
359
+ "sibling_nodes_condition":"",
360
+ "subtasks_of_this_node":[]
361
+ },
362
+ "Task_1.1.1.2":{
363
+ "task_id":"Task_1.1.1.2",
364
+ "task_instruction":"Move one half of the tomato and the watermelon to the yellow plate.",
365
+ "task_relied_description":"then transfer one half of the tomato and the halved watermelon to the yellow plate.",
366
+ "sibling_nodes_condition":"After slicing the fruits",
367
+ "subtasks_of_this_node":[]
368
+ },
369
+ "Task_1.1.2":{
370
+ "task_id":"Task_1.1.2",
371
+ "task_instruction":"Place vegetables on the blue plate, including dividing and placing the broccoli and separating egg yolk.",
372
+ "task_relied_description":"For the vegetables, initially place the ingredients on the cutting board for slicing before transferring them to a plate.",
373
+ "sibling_nodes_condition":"After placing the fruits on the yellow plate",
374
+ "subtasks_of_this_node":["Task_1.1.2.1", "Task_1.1.2.2"]
375
+ },
376
+ "Task_1.1.2.1":{
377
+ "task_id":"Task_1.1.2.1",
378
+ "task_instruction":"Place broccoli on blue cutting board, divide it into two pieces, and then move both pieces to the blue plate.",
379
+ "task_relied_description":"For the broccoli, lay it on the blue cutting board and divide it into two pieces. Afterwards, move both pieces to the blue plate.",
380
+ "sibling_nodes_condition":"",
381
+ "subtasks_of_this_node":[]
382
+ },
383
+ "Task_1.1.2.2":{
384
+ "task_id":"Task_1.1.2.2",
385
+ "task_instruction":"Take an egg, remove its shell, separate the yolk from the egg and put it directly into the blue plate.",
386
+ "task_relied_description":"When it comes to the egg yolk, first take an egg but don't place it on the cutting board, Directly separate the yolk into the blue plate after removing the shell.",
387
+ "sibling_nodes_condition":"After placing the broccoli",
388
+ "subtasks_of_this_node":[]
389
+ },
390
+ "Task_1.2":{
391
+ "task_id":"Task_1.2",
392
+ "task_instruction":"Prepare the eggs and meats.",
393
+ "task_relied_description":"eggs and meats should be prepped last.",
394
+ "sibling_nodes_condition":"After arranging the fruits and vegetables",
395
+ "subtasks_of_this_node":[]
396
+ }
397
+ }
398
+ ```"""))
399
+ print(reAsciiLTL2EngLTL("F ( P11 A ( F P04 ) )"))
400
+ print(splitJSONfromTXT("""
401
+ {
402
+ "Task_1": {
403
+ "task_id": "Task_1",
404
+ "task_instruction": "Prepare and cut banana, apple, onion and pepper. Follow rule: Do not cut fruits on a cutting board that has been used for cutting vegetables.",
405
+ "sibling_nodes_condition": "NA",
406
+ "subtasks_of_this_node":["Task_1.1", "Task_1.2", "Task_1.3"]
407
+ },
408
+ "Task_1.1": {
409
+ "task_id": "Task_1.1",
410
+ "task_instruction": "Prepare and cut banana and apple. Do not use a cutting board that has been used for cutting vegetables.",
411
+ "sibling_nodes_condition": "NA",
412
+ "subtasks_of_this_node":["Task_1.1.1", "Task_1.1.2"]
413
+ },
414
+ "Task_1.1.1": {
415
+ "task_id": "Task_1.1.1",
416
+ "task_instruction": "Prepare and cut banana",
417
+ "sibling_nodes_condition": "NA",
418
+ "subtasks_of_this_node": []
419
+ },
420
+ "Task_1.1.2": {
421
+ "task_id": "Task_1.1.2",
422
+ "task_instruction": "Prepare and cut apple",
423
+ "sibling_nodes_condition": "After preparing and cutting banana",
424
+ "subtasks_of_this_node": []
425
+ },
426
+ "Task_1.2": {
427
+ "task_id": "Task_1.2",
428
+ "task_instruction": "Prepare and cut onion and pepper",
429
+ "sibling_nodes_condition": "Do not use the same cutting board as used for fruits",
430
+ "subtasks_of_this_node": ["Task_1.2.1", "Task_1.2.2"]
431
+ },
432
+ "Task_1.2.1": {
433
+ "task_id": "Task_1.2.1",
434
+ "task_instruction": "Prepare and cut onion",
435
+ "sibling_nodes_condition": "NA",
436
+ "subtasks_of_this_node": []
437
+ },
438
+ "Task_1.2.2": {
439
+ "task_id": "Task_1.2.2",
440
+ "task_instruction": "Prepare and cut pepper",
441
+ "sibling_nodes_condition": "After preparing and cutting onion",
442
+ "subtasks_of_this_node": []
443
+ },
444
+ "Task_1.3": {
445
+ "task_id": "Task_1.3",
446
+ "task_instruction": "Rule: Do not cut fruits on a cutting board that has been used for cutting vegetables.",
447
+ "sibling_nodes_condition": "NA",
448
+ "subtasks_of_this_node": []
449
+ }}"""))
NL2TL-dataset/collect2/getUniqueLTL.py CHANGED
@@ -12,9 +12,9 @@ def findUniqueLTL(paths:list):
12
  return ret
13
 
14
  if __name__=='__main__':
15
- path=['/home/user/xsj/NL2TL-dataset/collect2/ltl_eng_test_mid_ascii_gptAuged.jsonl','/home/user/xsj/NL2TL-dataset/collect2/ltl_eng_train_mid_ascii_gptAuged.jsonl']
16
  LTLs=findUniqueLTL(paths=path)
17
- with open(os.path.join('/home/user/xsj/NL2TL-dataset/collect2','NLTLsummary.json'),'w') as f :
18
  f.write(json.dumps(LTLs,sort_keys=False,indent=4,separators=(',',':')))
19
 
20
 
 
12
  return ret
13
 
14
  if __name__=='__main__':
15
+ path=['path/to/NL2TL-dataset/collect2/ltl_eng_test_mid_ascii_gptAuged.jsonl','path/to/NL2TL-dataset/collect2/ltl_eng_train_mid_ascii_gptAuged.jsonl']
16
  LTLs=findUniqueLTL(paths=path)
17
+ with open(os.path.join('path/to/NL2TL-dataset/collect2','NLTLsummary.json'),'w') as f :
18
  f.write(json.dumps(LTLs,sort_keys=False,indent=4,separators=(',',':')))
19
 
20
 
README.md CHANGED
@@ -30,7 +30,19 @@ Based task related NL2TL datasets:
30
  - [Lang2LTL](https://github.com/h2r/Lang2LTL)
31
  - [nl2spec](https://github.com/realChrisHahn2/nl2spec)
32
  - [NL2TL](https://github.com/yongchao98/NL2TL)
33
-
 
 
 
 
 
 
 
 
 
 
 
 
34
  ## Cite
35
  ```bibtex
36
  @misc{xu2024scalingnaturallanguageunderstanding,
 
30
  - [Lang2LTL](https://github.com/h2r/Lang2LTL)
31
  - [nl2spec](https://github.com/realChrisHahn2/nl2spec)
32
  - [NL2TL](https://github.com/yongchao98/NL2TL)
33
+ ## File Structure
34
+ - NL2HLTL
35
+ - NL2HLTLTranslator
36
+ - fastapi_server.py: a FastAPI server for testing translation; runs on localhost:8001
37
+ - mistral7b
38
+ - finetune.py: code for fine-tuning
39
+ - prediction.py: code for prediction (this version does not use sockets)
40
+ - mistral7b_quat8: fine-tuned model based on Mistral7B (8-bit quantized)
41
+ - NL2TL-dataset used dataset
42
+ ## Run
43
+ ```bash
44
+ python NL2HLTLTranslator/fastapi_server.py
45
+ ```
46
  ## Cite
47
  ```bibtex
48
  @misc{xu2024scalingnaturallanguageunderstanding,
finetune/Llama2_13b/data/eval/tf-ltl_eng_test_mid_ascii_gptAuged/dataset_info.json DELETED
@@ -1,69 +0,0 @@
1
- {
2
- "builder_name": "json",
3
- "citation": "",
4
- "config_name": "default",
5
- "dataset_name": "json",
6
- "dataset_size": 889411,
7
- "description": "",
8
- "download_checksums": {
9
- "LTL_datasets/collect/ltl_eng_train_mid_ascii_gptAuged.jsonl": {
10
- "num_bytes": 1129386,
11
- "checksum": null
12
- },
13
- "LTL_datasets/collect/ltl_eng_test_mid_ascii_gptAuged.jsonl": {
14
- "num_bytes": 125920,
15
- "checksum": null
16
- }
17
- },
18
- "download_size": 1255306,
19
- "features": {
20
- "id": {
21
- "dtype": "string",
22
- "_type": "Value"
23
- },
24
- "input_ids": {
25
- "feature": {
26
- "dtype": "int32",
27
- "_type": "Value"
28
- },
29
- "_type": "Sequence"
30
- },
31
- "attention_mask": {
32
- "feature": {
33
- "dtype": "int8",
34
- "_type": "Value"
35
- },
36
- "_type": "Sequence"
37
- },
38
- "labels": {
39
- "feature": {
40
- "dtype": "int64",
41
- "_type": "Value"
42
- },
43
- "_type": "Sequence"
44
- }
45
- },
46
- "homepage": "",
47
- "license": "",
48
- "size_in_bytes": 2144717,
49
- "splits": {
50
- "train": {
51
- "name": "train",
52
- "num_bytes": 800102,
53
- "num_examples": 10621,
54
- "dataset_name": "json"
55
- },
56
- "test": {
57
- "name": "test",
58
- "num_bytes": 89309,
59
- "num_examples": 1181,
60
- "dataset_name": "json"
61
- }
62
- },
63
- "version": {
64
- "version_str": "0.0.0",
65
- "major": 0,
66
- "minor": 0,
67
- "patch": 0
68
- }
69
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
finetune/Llama2_13b/data/eval/tf-ltl_eng_test_mid_ascii_gptAuged/state.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "_data_files": [
3
- {
4
- "filename": "data-00000-of-00001.arrow"
5
- }
6
- ],
7
- "_fingerprint": "c6bf809a7a8f99a6",
8
- "_format_columns": null,
9
- "_format_kwargs": {},
10
- "_format_type": null,
11
- "_output_all_columns": false,
12
- "_split": "test"
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
finetune/Llama2_13b/data/train/tf-ltl_eng_test_mid_ascii_gptAuged/dataset_info.json DELETED
@@ -1,69 +0,0 @@
1
- {
2
- "builder_name": "json",
3
- "citation": "",
4
- "config_name": "default",
5
- "dataset_name": "json",
6
- "dataset_size": 889411,
7
- "description": "",
8
- "download_checksums": {
9
- "LTL_datasets/collect/ltl_eng_train_mid_ascii_gptAuged.jsonl": {
10
- "num_bytes": 1129386,
11
- "checksum": null
12
- },
13
- "LTL_datasets/collect/ltl_eng_test_mid_ascii_gptAuged.jsonl": {
14
- "num_bytes": 125920,
15
- "checksum": null
16
- }
17
- },
18
- "download_size": 1255306,
19
- "features": {
20
- "id": {
21
- "dtype": "string",
22
- "_type": "Value"
23
- },
24
- "input_ids": {
25
- "feature": {
26
- "dtype": "int32",
27
- "_type": "Value"
28
- },
29
- "_type": "Sequence"
30
- },
31
- "attention_mask": {
32
- "feature": {
33
- "dtype": "int8",
34
- "_type": "Value"
35
- },
36
- "_type": "Sequence"
37
- },
38
- "labels": {
39
- "feature": {
40
- "dtype": "int64",
41
- "_type": "Value"
42
- },
43
- "_type": "Sequence"
44
- }
45
- },
46
- "homepage": "",
47
- "license": "",
48
- "size_in_bytes": 2144717,
49
- "splits": {
50
- "train": {
51
- "name": "train",
52
- "num_bytes": 800102,
53
- "num_examples": 10621,
54
- "dataset_name": "json"
55
- },
56
- "test": {
57
- "name": "test",
58
- "num_bytes": 89309,
59
- "num_examples": 1181,
60
- "dataset_name": "json"
61
- }
62
- },
63
- "version": {
64
- "version_str": "0.0.0",
65
- "major": 0,
66
- "minor": 0,
67
- "patch": 0
68
- }
69
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
finetune/Llama2_13b/data/train/tf-ltl_eng_test_mid_ascii_gptAuged/state.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "_data_files": [
3
- {
4
- "filename": "data-00000-of-00001.arrow"
5
- }
6
- ],
7
- "_fingerprint": "afb9c85014ff4b4e",
8
- "_format_columns": null,
9
- "_format_kwargs": {},
10
- "_format_type": null,
11
- "_output_all_columns": false,
12
- "_split": "train"
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
setup.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from setuptools import setup, find_packages

# Distribution metadata for the NL2HLTLTranslator package; every package
# discovered under the project root is included.
PACKAGE_METADATA = dict(
    name='NL2HLTLTranslator',
    version='0.2',
    author='xsj',
    author_email='[email protected]',
    description='the package is used for multi robot task execution in the aithor env, under the instruction structure of LTL',
    packages=find_packages(),
    install_requires=[],
    license='MIT',
    url='https://github.com/darrrt/NL2HLTL',
)

setup(**PACKAGE_METADATA)