File size: 4,700 Bytes
57cf043
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import datetime
import logging

from fastapi import HTTPException
from sqlalchemy.orm import Session

from components.dbo.models.llm_prompt import LlmPrompt as LlmPromptSQL
from schemas.llm_prompt import LlmPromptCreateSchema, LlmPromptSchema


logger = logging.getLogger(__name__)


class LlmPromptService:
    """
    Сервис для работы с параметрами LLM.
    """

    def __init__(self, db: Session):
        logger.info("LlmPromptService initializing")
        self.db = db

    
    def create(self, prompt_schema: LlmPromptCreateSchema):
        logger.info("Creating a new prompt")
        with self.db() as session:
            new_prompt: LlmPromptSQL = LlmPromptSQL(**prompt_schema.model_dump())
            session.add(new_prompt)
            session.commit()
            session.refresh(new_prompt)
            
            if(new_prompt.is_default):
                self.set_as_default(new_prompt.id)
                
            return LlmPromptSchema(**new_prompt.to_dict())


    def get_list(self) -> list[LlmPromptSchema]:
        with self.db() as session:
            prompts: list[LlmPromptSQL] = session.query(LlmPromptSQL).all()
            
            return [
                LlmPromptSchema(**prompt.to_dict())
                for prompt in prompts
            ]
    
    def get_by_id(self, id: int) -> LlmPromptSchema:
        with self.db() as session:
            prompt: LlmPromptSQL = session.query(LlmPromptSQL).filter(LlmPromptSQL.id == id).first()

            if not prompt:
                raise HTTPException(
                    status_code=400, detail=f"Item with id {id} not found"
                    )

            return LlmPromptSchema(**prompt.to_dict())


    def get_default(self) -> LlmPromptSchema:
        with self.db() as session:
            prompt: LlmPromptSQL = session.query(LlmPromptSQL).filter(LlmPromptSQL.is_default).first()

            if not prompt:
                # Возвращаем дефолтнейший промпт в случае, если ничего нет. 
                # Неочевидно, но в случае факапа всё работать будет.
                return LlmPromptSchema(
                    is_default=True,
                    text='Ты ассистент. Ты помогаешь мне. Ты следуешь моим инструкциям.',
                    name='fallback',
                    id=0,
                    type="system",
                    date_created=datetime.datetime.now(datetime.timezone.utc)
                )

            return LlmPromptSchema(**prompt.to_dict())


    def set_as_default(self, id: int):
        logger.info(f"Set default prompt: {id}")
        
        with self.db() as session:
            session.query(LlmPromptSQL).filter(LlmPromptSQL.is_default).update({"is_default": False})
            prompt_new: LlmPromptSQL = session.query(LlmPromptSQL).filter(LlmPromptSQL.id == id).first()

            if not prompt_new:
                raise HTTPException(
                    status_code=400, detail=f"Item with id {id} not found"
                )

            prompt_new.is_default = True
            session.commit()


    
    def update(self, id: int, new_prompt: LlmPromptSchema): 
        logger.info("Updating default prompt")
        with self.db() as session:
            prompt: LlmPromptSQL = session.query(LlmPromptSQL).filter(LlmPromptSQL.id == id).first()  

            if not prompt:
                raise HTTPException(
                    status_code=400, detail=f"Item with id {id} not found"
                    )
        
            update_data = new_prompt.model_dump(exclude_unset=True)
            
            for key, value in update_data.items():
                if hasattr(prompt, key):
                    setattr(prompt, key, value)
            
            
            session.commit()
            
            
            if(new_prompt.is_default):
                self.set_as_default(new_prompt.id)
                
            session.refresh(prompt)
            return prompt


    def delete(self, id: int): 
        logger.info("Deleting prompt: {id}")
        with self.db() as session:
            prompt_to_del: LlmPromptSQL = session.query(LlmPromptSQL).get(id)
            prompt_default: LlmPromptSQL = session.query(LlmPromptSQL).filter(LlmPromptSQL.is_default).first()  

            if prompt_to_del.id == prompt_default.id:
                raise HTTPException(
                    status_code=400, detail=f"The default prompt cannot be deleted"
                    )
            
            session.delete(prompt_to_del)
            session.commit()