Ver código fonte

Fix CI tests for Mem0 (#1498)

Dev Khant 1 ano atrás
pai
commit
e9136c1aa0

+ 61 - 4
.github/workflows/ci.yml

@@ -4,17 +4,74 @@ on:
   push:
     branches: [main]
     paths:
+      - 'mem0/**'
+      - 'tests/**'
       - 'embedchain/**'
       - 'embedchain/tests/**'
       - 'embedchain/examples/**'
   pull_request:
     paths:
-      - 'embedchain/embedchain/**'
+      - 'mem0/**'
+      - 'tests/**'
+      - 'embedchain/**'
       - 'embedchain/tests/**'
       - 'embedchain/examples/**'
 
 jobs:
-  build:
+  check_changes:
+    runs-on: ubuntu-latest
+    outputs:
+      mem0_changed: ${{ steps.filter.outputs.mem0 }}
+      embedchain_changed: ${{ steps.filter.outputs.embedchain }}
+    steps:
+    - uses: actions/checkout@v3
+    - uses: dorny/paths-filter@v2
+      id: filter
+      with:
+        filters: |
+          mem0:
+            - 'mem0/**'
+            - 'tests/**'
+          embedchain:
+            - 'embedchain/**'
+            - 'embedchain/tests/**'
+            - 'embedchain/examples/**'
+
+  build_mem0:
+    needs: check_changes
+    if: ${{ needs.check_changes.outputs.mem0_changed == 'true' || (needs.check_changes.outputs.mem0_changed == 'false' && needs.check_changes.outputs.embedchain_changed == 'false') }}
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11"]
+
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install poetry
+        uses: snok/install-poetry@v1
+        with:
+          version: 1.4.2
+          virtualenvs-create: true
+          virtualenvs-in-project: true
+      - name: Load cached venv
+        id: cached-poetry-dependencies
+        uses: actions/cache@v2
+        with:
+          path: .venv
+          key: venv-mem0-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
+      - name: Install dependencies
+        run: make install_all
+        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
+      - name: Run tests and generate coverage report
+        run: make test
+
+  build_embedchain:
+    needs: check_changes
+    if: ${{ needs.check_changes.outputs.embedchain_changed == 'true' || (needs.check_changes.outputs.mem0_changed == 'false' && needs.check_changes.outputs.embedchain_changed == 'false') }}
     runs-on: ubuntu-latest
     strategy:
       matrix:
@@ -37,7 +94,7 @@ jobs:
         uses: actions/cache@v2
         with:
           path: .venv
-          key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
+          key: venv-embedchain-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
       - name: Install dependencies
         run: cd embedchain && make install_all
         if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
@@ -50,4 +107,4 @@ jobs:
         with:
           file: coverage.xml
         env:
-          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

+ 11 - 2
Makefile

@@ -1,12 +1,18 @@
 .PHONY: format sort lint
 
 # Variables
-RUFF_OPTIONS = --line-length 120
 ISORT_OPTIONS = --profile black
+PROJECT_NAME := mem0ai
 
 # Default target
 all: format sort lint
 
+install:
+	poetry install
+
+install_all:
+	poetry install
+
 # Format code with ruff
 format:
 	poetry run ruff check . --fix $(RUFF_OPTIONS)
@@ -17,7 +23,7 @@ sort:
 
 # Lint code with ruff
 lint:
-	poetry run ruff check . $(RUFF_OPTIONS)
+	poetry run ruff .
 
 docs:
 	cd docs && mintlify dev
@@ -30,3 +36,6 @@ publish:
 
 clean:
 	poetry run rm -rf dist
+
+test:
+	poetry run pytest

+ 1 - 1
embedchain/embedchain/llm/aws_bedrock.py

@@ -23,7 +23,7 @@ class AWSBedrockLlm(BaseLlm):
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The required dependencies for AWSBedrock are not installed."
-                "Please install with `pip install boto3==1.34.20`"
+                "Please install with `pip install boto3==1.34.20`."
             ) from None
 
         self.boto_client = boto3.client("bedrock-runtime", "us-west-2" or os.environ.get("AWS_REGION"))

+ 5 - 1
mem0/llms/aws_bedrock.py

@@ -14,7 +14,11 @@ class AWSBedrockLLM(LLMBase):
         if not self.config.model:
             self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
         self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"))
-        self.model_kwargs = {"temperature": self.config.temperature, "max_tokens_to_sample": self.config.max_tokens, "top_p": self.config.top_p}
+        self.model_kwargs = {
+            "temperature": self.config.temperature, 
+            "max_tokens_to_sample": self.config.max_tokens, 
+            "top_p": self.config.top_p
+        }
 
     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
         """