Skip to content

Commit f52cb05

Browse files
authored
implement posts clustering (#441)
* implement posts clustering #421 * use pgvector-enabled pgdb server for testing * implement build endpoints #421 * add missing environment * mock create_embedding in tasks * remove build embedding
1 parent 1e8c35f commit f52cb05

33 files changed

Lines changed: 1860 additions & 28 deletions

.env.example

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,9 @@ HISTORY4FEED_EARLIEST_SEARCH_DATE=
5555
HISTORY4FEED_WAYBACK_SLEEP_SECONDS=
5656
HISTORY4FEED_REQUEST_RETRY_COUNT=
5757
# pdfshift
58-
PDFSHIFT_API_KEY=
58+
PDFSHIFT_API_KEY=
59+
60+
# clustering settings
61+
CLASSIFIER_MIN_CLUSTER_SIZE=
62+
CLASSIFIER_LABEL_SAMPLE_SIZE=
63+
CLASSIFIER_CONCURRENCY=

.env.markdown

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,13 @@ If you're not using a Proxy it is very likely you'll run into rate limits on the
160160
## pdfshift
161161

162162
* `PDFSHIFT_API_KEY`: get from `https://app.pdfshift.io/`
163-
* is used to generate PDFs from posts. If you use the generate PDF setting in profile, this variable must be used.
163+
* is used to generate PDFs from posts. If you use the generate PDF setting in profile, this variable must be used.
164+
165+
166+
## clustering settings
167+
* `CLASSIFIER_MIN_CLUSTER_SIZE`: `5-10`
168+
* This is the minimum number of posts that should be in a cluster. Clusters smaller than this will be discarded. Setting this value too low may result in many small clusters that are not meaningful, while setting this value too high may result in missing out on smaller but still relevant clusters.
169+
* `CLASSIFIER_LABEL_SAMPLE_SIZE`: `5-20`
170+
* This is the number of posts that will be sampled from each cluster to generate a label. Setting this value too low may result in less accurate labels, while setting this value too high may result in increased processing time.
171+
* `CLASSIFIER_CONCURRENCY`: `12`
172+
* This is the number of worker threads to use for concurrent labelling of clusters. Adjust this value according to your system's capabilities and the volume of data being processed.

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM python:3.11-slim
1+
FROM python:3.11
22
ENV PYTHONUNBUFFERED=1
33

44
WORKDIR /usr/src/app

Dockerfile.deploy

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM python:3.11-slim
1+
FROM python:3.11
22
ENV PYTHONUNBUFFERED=1
33

44
ARG DJANGO_DEBUG=

Dockerfile.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM python:3.11-slim
1+
FROM python:3.11
22
ENV PYTHONUNBUFFERED=1
33

44
ARG DJANGO_DEBUG=

docker-compose.yml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ services:
2424
- 8001:8001
2525
depends_on:
2626
- celery
27-
2827
celery:
2928
extends: env_django
3029
command: >
@@ -36,5 +35,13 @@ services:
3635
condition: service_started
3736
env_django:
3837
condition: service_completed_successfully
38+
environment:
39+
- CLASSIFIER_MODEL_PATH=/opt/clusters/classfier_hdbscan.joblib
40+
volumes:
41+
- clusters:/opt/clusters/
42+
3943
redis:
40-
image: "redis:alpine"
44+
image: "redis:alpine"
45+
46+
volumes:
47+
clusters:

obstracts/cjob/tasks.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import io
22
import logging
33
import uuid
4+
from concurrent.futures import ThreadPoolExecutor, as_completed
45
from celery import shared_task, chain, current_task, Task as CeleryTask
56
from django.db import transaction
67
import typing
@@ -11,6 +12,8 @@
1112
import requests
1213

1314
from obstracts.cjob import helpers
15+
from obstracts.classifier.models import DocumentEmbedding
16+
import obstracts.classifier.tasks as classifier_tasks
1417
from ..server.models import Job
1518
from ..server import models
1619
from django.core.cache import cache
@@ -192,6 +195,99 @@ def update_vulnerabilities(job_id):
192195
state = models.JobState.PROCESS_FAILED
193196
job.update_state(state)
194197

198+
199+
def _build_topic_embedding_for_post(post_id, force=False):
200+
try:
201+
post_file = models.File.objects.select_related("post").get(pk=post_id)
202+
post_file.create_embedding(force=force)
203+
if post_file.embedding_id:
204+
return "processed", None
205+
return "failed", f"embedding not created for post {post_id}"
206+
except Exception:
207+
logging.exception("embedding build failed for post %s", post_id)
208+
return "failed", f"embedding build failed for post {post_id}"
209+
210+
211+
def run_topic_embeddings_job(job_id, force=False):
212+
job = models.Job.objects.get(pk=job_id)
213+
try:
214+
qs = models.File.objects.filter(
215+
processed=True,
216+
ai_describes_incident=True,
217+
)
218+
if not force:
219+
qs = qs.filter(embedding__isnull=True)
220+
221+
post_ids = list(qs.values_list("post_id", flat=True))
222+
if not post_ids:
223+
job.update_state(models.JobState.PROCESSED)
224+
return
225+
226+
cancelled = False
227+
228+
with ThreadPoolExecutor(max_workers=settings.CLASSIFIER_CONCURRENCY) as pool:
229+
futures = {
230+
pool.submit(_build_topic_embedding_for_post, post_id, force): post_id
231+
for post_id in post_ids
232+
}
233+
for future in as_completed(futures):
234+
status, msg = future.result()
235+
if job.is_cancelled():
236+
cancelled = True
237+
pool.shutdown(wait=False, cancel_futures=True)
238+
if status == "processed":
239+
job.processed_items += 1
240+
elif status == "failed":
241+
job.failed_processes += 1
242+
if msg:
243+
job.errors.append(msg)
244+
if cancelled:
245+
job.update_state(models.JobState.CANCELLED)
246+
elif job.failed_processes and job.processed_items == 0:
247+
job.update_state(models.JobState.PROCESS_FAILED)
248+
else:
249+
job.update_state(models.JobState.PROCESSED)
250+
except Exception as e:
251+
logging.exception("topic embedding task failed")
252+
job.failed_processes += 1
253+
job.errors.append(str(e))
254+
job.update_state(models.JobState.PROCESS_FAILED)
255+
finally:
256+
job.save(update_fields=["errors", "processed_items", "failed_processes"])
257+
258+
259+
def run_topic_clusters_job(job_id, force=False):
260+
job = models.Job.objects.get(pk=job_id)
261+
try:
262+
if job.is_cancelled():
263+
job.update_state(models.JobState.CANCELLED)
264+
return
265+
266+
classifier_tasks.run_clustering(
267+
force=force,
268+
workers=settings.CLASSIFIER_CONCURRENCY,
269+
should_cancel=lambda: models.Job.objects.get(pk=job_id).is_cancelled(),
270+
)
271+
if job.is_cancelled():
272+
job.update_state(models.JobState.CANCELLED)
273+
return
274+
job.processed_items += 1
275+
job.update_state(models.JobState.PROCESSED)
276+
except classifier_tasks.ClusteringCancelled:
277+
job.update_state(models.JobState.CANCELLED)
278+
except Exception as e:
279+
logging.exception("topic cluster task failed")
280+
job.failed_processes += 1
281+
job.errors.append(str(e))
282+
job.update_state(models.JobState.PROCESS_FAILED)
283+
finally:
284+
job.save(update_fields=["errors", "processed_items", "failed_processes"])
285+
286+
287+
@shared_task
288+
def build_topic_clusters(job_id, force=False):
289+
run_topic_clusters_job(job_id, force=force)
290+
195291
@shared_task
196292
def add_pdf_to_post(job_id, post_id):
197293
job = models.Job.objects.get(pk=job_id)
@@ -298,6 +394,7 @@ def process_post(self, job_id, post_id, profile_id=None, *args):
298394
)
299395

300396
file.set_txt2stix_data(processor.txt2stix_data)
397+
file.create_embedding()
301398

302399
file.processed = True
303400
file.save(

obstracts/classifier/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
default_app_config = 'obstracts.classifier.apps.ClassifierConfig'

obstracts/classifier/admin.py

Whitespace-only changes.

obstracts/classifier/apps.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from django.apps import AppConfig
2+
3+
4+
class ClassifierConfig(AppConfig):
5+
default_auto_field = "django.db.models.BigAutoField"
6+
name = "obstracts.classifier"
7+
verbose_name = "Classifier"

0 commit comments

Comments
 (0)