Spaces:
Sleeping
Sleeping
Update pipeline.py
Browse files- pipeline.py +5 -4
pipeline.py
CHANGED
|
@@ -68,16 +68,16 @@ def classify_query(query: str) -> str:
|
|
| 68 |
return classification if classification != "OutOfScope" else "OutOfScope"
|
| 69 |
|
| 70 |
# Function to moderate text using Mistral moderation API (async version)
|
| 71 |
-
|
| 72 |
try:
|
| 73 |
# Use Pydantic AI to validate the text
|
| 74 |
-
|
| 75 |
except Exception as e:
|
| 76 |
print(f"Error validating text: {e}")
|
| 77 |
return "Invalid text format."
|
| 78 |
|
| 79 |
# Call the Mistral moderation API
|
| 80 |
-
response =
|
| 81 |
model="mistral-moderation-latest",
|
| 82 |
inputs=[{"role": "user", "content": query}]
|
| 83 |
)
|
|
@@ -201,7 +201,8 @@ async def run_async_pipeline(query: str) -> str:
|
|
| 201 |
|
| 202 |
# Run the pipeline with the event loop
|
| 203 |
def run_with_chain(query: str) -> str:
|
| 204 |
-
|
|
|
|
| 205 |
|
| 206 |
# Initialize chains here
|
| 207 |
classification_chain = get_classification_chain()
|
|
|
|
| 68 |
return classification if classification != "OutOfScope" else "OutOfScope"
|
| 69 |
|
| 70 |
# Function to moderate text using Mistral moderation API (async version)
|
| 71 |
+
def moderate_text(query: str) -> str:
|
| 72 |
try:
|
| 73 |
# Use Pydantic AI to validate the text
|
| 74 |
+
pydantic_agent.run_sync(query) # Use sync run for Pydantic validation
|
| 75 |
except Exception as e:
|
| 76 |
print(f"Error validating text: {e}")
|
| 77 |
return "Invalid text format."
|
| 78 |
|
| 79 |
# Call the Mistral moderation API
|
| 80 |
+
response = client.classifiers.moderate_chat(
|
| 81 |
model="mistral-moderation-latest",
|
| 82 |
inputs=[{"role": "user", "content": query}]
|
| 83 |
)
|
|
|
|
| 201 |
|
| 202 |
# Run the pipeline with the event loop
|
| 203 |
def run_with_chain(query: str) -> str:
|
| 204 |
+
loop = asyncio.get_event_loop()
|
| 205 |
+
return loop.run_until_complete(run_async_pipeline(query))
|
| 206 |
|
| 207 |
# Initialize chains here
|
| 208 |
classification_chain = get_classification_chain()
|